mirror of
https://github.com/langgenius/dify.git
synced 2026-01-28 15:56:00 +08:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5bffa1d918 | |||
| c9b0fe47bf | |||
| bcd744b6b7 | |||
| 5e511e01bf | |||
| 52291c645e | |||
| a31466d34e | |||
| d38eac959b | |||
| 9dbb8acd4b | |||
| 46154c6705 | |||
| 54ff03c35d | |||
| 18c710c906 | |||
| 59236b789f | |||
| fd3d43cae1 | |||
| 8eae643911 | |||
| fd9413874a | |||
| 227f9fb77d | |||
| c40ee7e629 | |||
| 841e967d48 | |||
| 9df0dcedae | |||
| 724e053732 | |||
| e409895c02 | |||
| 32d9b6181c | |||
| 2b018fade2 | |||
| e65f9cb17a | |||
| 1367f34398 | |||
| e47f6b879a | |||
| 5809edd74b | |||
| 05bfa11915 | |||
| 435f804c6f | |||
| ae3f1ac0a9 | |||
| 269a465fc4 | |||
| 60e0bbd713 | |||
| 827c97f0d3 | |||
| c8bd76cd66 | |||
| ec5f585df4 | |||
| 1de48f33ca | |||
| 6b41a9593e | |||
| 82267083e8 | |||
| c385961d33 | |||
| 20bab6edec | |||
| 67bed54f32 | |||
| 562a571281 | |||
| fc68c81791 | |||
| 5d9070bc60 | |||
| b11fb0dfd1 | |||
| d1c5c5f160 | |||
| 0b1d1440aa | |||
| 0c420d64b3 | |||
| f9082104ed | |||
| 983834cd52 | |||
| 96d10c8b39 |
11
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
Normal file
11
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
name: "🤝 Help Wanted"
|
||||
description: "Request help from the community"
|
||||
labels:
|
||||
- help-wanted
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Provide a description of the help you need
|
||||
placeholder: Briefly describe what you need help with.
|
||||
validations:
|
||||
required: true
|
||||
@ -16,6 +16,10 @@ Out-of-the-box web sites supporting form mode and chat conversation mode
|
||||
A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
|
||||
Visual data analysis, log review, and annotation for applications
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
|
||||
|
||||
|
||||
## Highlighted Features
|
||||
**1. LLMs support:** Choose capabilities based on different models when building your Dify AI apps. Dify is compatible with Langchain, meaning it will support various LLMs. Currently supported:
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
|
||||
- 可视化的对应用进行数据分析,查阅日志或进行标注
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
|
||||
|
||||
## 核心能力
|
||||
1. **模型支持:** 你可以在 Dify 上选择基于不同模型的能力来开发你的 AI 应用。Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商:
|
||||
|
||||
@ -50,24 +50,6 @@ S3_REGION=your-region
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Cookie configuration
|
||||
COOKIE_HTTPONLY=true
|
||||
COOKIE_SAMESITE=None
|
||||
COOKIE_SECURE=true
|
||||
|
||||
# Session configuration
|
||||
SESSION_PERMANENT=true
|
||||
SESSION_USE_SIGNER=true
|
||||
|
||||
## support redis, sqlalchemy
|
||||
SESSION_TYPE=redis
|
||||
|
||||
# session redis configuration
|
||||
SESSION_REDIS_HOST=localhost
|
||||
SESSION_REDIS_PORT=6379
|
||||
SESSION_REDIS_PASSWORD=difyai123456
|
||||
SESSION_REDIS_DB=2
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
|
||||
91
api/app.py
91
api/app.py
@ -1,8 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
@ -12,12 +11,11 @@ import logging
|
||||
import json
|
||||
import threading
|
||||
|
||||
from flask import Flask, request, Response, session
|
||||
import flask_login
|
||||
from flask import Flask, request, Response
|
||||
from flask_cors import CORS
|
||||
|
||||
from core.model_providers.providers import hosted
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage, ext_mail, ext_stripe
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
|
||||
from events import event_handlers
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
import core
|
||||
from config import Config, CloudEditionConfig
|
||||
from commands import register_commands
|
||||
from models.account import TenantAccountJoin, AccountStatus
|
||||
from models.model import Account, EndUser, App
|
||||
from services.account_service import TenantService
|
||||
from services.account_service import AccountService
|
||||
from libs.passport import PassportService
|
||||
|
||||
import warnings
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
@ -85,81 +81,33 @@ def initialize_extensions(app):
|
||||
ext_redis.init_app(app)
|
||||
ext_storage.init_app(app)
|
||||
ext_celery.init_app(app)
|
||||
ext_session.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
ext_stripe.init_app(app)
|
||||
|
||||
|
||||
def _create_tenant_for_account(account):
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
account.current_tenant = tenant
|
||||
|
||||
return tenant
|
||||
|
||||
|
||||
# Flask-Login configuration
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
"""Load user based on the user_id."""
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint == 'console':
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
if '.' in user_id:
|
||||
tenant_id, account_id = user_id.split('.')
|
||||
else:
|
||||
account_id = user_id
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get('user_id')
|
||||
|
||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if account:
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise Forbidden('Account is banned or closed.')
|
||||
|
||||
workspace_id = session.get('workspace_id')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
TenantAccountJoin.tenant_id == workspace_id
|
||||
).first()
|
||||
|
||||
if not tenant_account_join:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.account_id == account.id).first()
|
||||
|
||||
if tenant_account_join:
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
else:
|
||||
_create_tenant_for_account(account)
|
||||
session['workspace_id'] = account.current_tenant_id
|
||||
else:
|
||||
account.current_tenant_id = workspace_id
|
||||
else:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.account_id == account.id).first()
|
||||
if tenant_account_join:
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
else:
|
||||
_create_tenant_for_account(account)
|
||||
session['workspace_id'] = account.current_tenant_id
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
# update last_active_at when last_active_at is more than 10 minutes ago
|
||||
if current_time - account.last_active_at > timedelta(minutes=10):
|
||||
account.last_active_at = current_time
|
||||
db.session.commit()
|
||||
|
||||
# Log in the user with the updated user_id
|
||||
flask_login.login_user(account, remember=True)
|
||||
|
||||
return account
|
||||
return AccountService.load_user(user_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@login_manager.unauthorized_handler
|
||||
def unauthorized_handler():
|
||||
"""Handle unauthorized requests."""
|
||||
@ -216,6 +164,7 @@ if app.config['TESTING']:
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.set_cookie('remember_token', '', expires=0)
|
||||
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
||||
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
||||
return response
|
||||
|
||||
213
api/commands.py
213
api/commands.py
@ -3,11 +3,13 @@ import json
|
||||
import math
|
||||
import random
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import click
|
||||
from tqdm import tqdm
|
||||
from flask import current_app
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -23,7 +25,7 @@ from libs.helper import email as email_validate
|
||||
from extensions.ext_database import db
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant, TenantAccountJoin
|
||||
from models.dataset import Dataset, DatasetQuery, Document
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
|
||||
from models.model import Account, AppModelConfig, App
|
||||
import secrets
|
||||
import base64
|
||||
@ -239,7 +241,13 @@ def clean_unused_dataset_indexes():
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete()
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
vector_index.delete()
|
||||
kw_index.delete()
|
||||
# update document
|
||||
update_params = {
|
||||
@ -346,7 +354,8 @@ def create_qdrant_indexes():
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
@ -364,7 +373,8 @@ def create_qdrant_indexes():
|
||||
index.create_qdrant_dataset(dataset)
|
||||
index_struct = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
"vector_store": {
|
||||
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct)
|
||||
db.session.commit()
|
||||
@ -373,7 +383,8 @@ def create_qdrant_indexes():
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
|
||||
@ -414,7 +425,8 @@ def update_qdrant_indexes():
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
@ -435,11 +447,104 @@ def update_qdrant_indexes():
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('normalization-collections', help='restore all collections in one')
|
||||
def normalization_collections():
|
||||
click.echo(click.style('Start normalization collections.', fg='green'))
|
||||
normalization_count = []
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
|
||||
except NotFound:
|
||||
break
|
||||
datasets_result = datasets.items
|
||||
page += 1
|
||||
for i in range(0, len(datasets_result), 5):
|
||||
threads = []
|
||||
sub_datasets = datasets_result[i:i + 5]
|
||||
for dataset in sub_datasets:
|
||||
document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'normalization_count': normalization_count
|
||||
})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
|
||||
|
||||
|
||||
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
click.echo('restore dataset index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
||||
DatasetCollectionBinding.model_name == embedding_model.name). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=embedding_model.model_provider.provider_name,
|
||||
model_name=embedding_model.name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
# index.delete_by_group_id(dataset.id)
|
||||
index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
||||
else:
|
||||
click.echo('passed.')
|
||||
normalization_count.append(1)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
|
||||
|
||||
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
|
||||
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
|
||||
def update_app_model_configs(batch_size):
|
||||
@ -473,7 +578,7 @@ def update_app_model_configs(batch_size):
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.count()
|
||||
|
||||
|
||||
if total_records == 0:
|
||||
click.secho("No data to migrate.", fg='green')
|
||||
return
|
||||
@ -485,14 +590,14 @@ def update_app_model_configs(batch_size):
|
||||
offset = i * batch_size
|
||||
limit = min(batch_size, total_records - offset)
|
||||
|
||||
click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
|
||||
|
||||
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
|
||||
|
||||
data_batch = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.order_by(App.created_at) \
|
||||
.offset(offset).limit(limit).all()
|
||||
|
||||
|
||||
if not data_batch:
|
||||
click.secho("No more data to migrate.", fg='green')
|
||||
break
|
||||
@ -512,7 +617,7 @@ def update_app_model_configs(batch_size):
|
||||
app_data = db.session.query(App) \
|
||||
.filter(App.id == data.app_id) \
|
||||
.one()
|
||||
|
||||
|
||||
account_data = db.session.query(Account) \
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
@ -534,13 +639,85 @@ def update_app_model_configs(batch_size):
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
|
||||
fg='red')
|
||||
continue
|
||||
|
||||
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
@click.command('migrate_default_input_to_dataset_query_variable')
|
||||
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
|
||||
def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
|
||||
click.secho("Starting...", fg='green')
|
||||
|
||||
total_records = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.filter(AppModelConfig.dataset_query_variable == None) \
|
||||
.count()
|
||||
|
||||
if total_records == 0:
|
||||
click.secho("No data to migrate.", fg='green')
|
||||
return
|
||||
|
||||
num_batches = (total_records + batch_size - 1) // batch_size
|
||||
|
||||
with tqdm(total=total_records, desc="Migrating Data") as pbar:
|
||||
for i in range(num_batches):
|
||||
offset = i * batch_size
|
||||
limit = min(batch_size, total_records - offset)
|
||||
|
||||
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
|
||||
|
||||
data_batch = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.filter(AppModelConfig.dataset_query_variable == None) \
|
||||
.order_by(App.created_at) \
|
||||
.offset(offset).limit(limit).all()
|
||||
|
||||
if not data_batch:
|
||||
click.secho("No more data to migrate.", fg='green')
|
||||
break
|
||||
|
||||
try:
|
||||
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
|
||||
for data in data_batch:
|
||||
config = AppModelConfig.to_dict(data)
|
||||
|
||||
tools = config["agent_mode"]["tools"]
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
if not dataset_exists:
|
||||
continue
|
||||
|
||||
user_input_form = config.get("user_input_form", [])
|
||||
for form in user_input_form:
|
||||
paragraph = form.get('paragraph')
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'query':
|
||||
data.dataset_query_variable = 'query'
|
||||
break
|
||||
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'default_input':
|
||||
data.dataset_query_variable = 'default_input'
|
||||
break
|
||||
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
|
||||
fg='red')
|
||||
continue
|
||||
|
||||
click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
|
||||
|
||||
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
@ -551,4 +728,6 @@ def register_commands(app):
|
||||
app.cli.add_command(clean_unused_dataset_indexes)
|
||||
app.cli.add_command(create_qdrant_indexes)
|
||||
app.cli.add_command(update_qdrant_indexes)
|
||||
app.cli.add_command(update_app_model_configs)
|
||||
app.cli.add_command(update_app_model_configs)
|
||||
app.cli.add_command(normalization_collections)
|
||||
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
|
||||
|
||||
@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client
|
||||
dotenv.load_dotenv()
|
||||
|
||||
DEFAULTS = {
|
||||
'COOKIE_HTTPONLY': 'True',
|
||||
'COOKIE_SECURE': 'True',
|
||||
'COOKIE_SAMESITE': 'None',
|
||||
'DB_USERNAME': 'postgres',
|
||||
'DB_PASSWORD': '',
|
||||
'DB_HOST': 'localhost',
|
||||
@ -22,10 +19,6 @@ DEFAULTS = {
|
||||
'REDIS_PORT': '6379',
|
||||
'REDIS_DB': '0',
|
||||
'REDIS_USE_SSL': 'False',
|
||||
'SESSION_REDIS_HOST': 'localhost',
|
||||
'SESSION_REDIS_PORT': '6379',
|
||||
'SESSION_REDIS_DB': '2',
|
||||
'SESSION_REDIS_USE_SSL': 'False',
|
||||
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
||||
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
||||
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
||||
@ -36,9 +29,6 @@ DEFAULTS = {
|
||||
'STORAGE_TYPE': 'local',
|
||||
'STORAGE_LOCAL_PATH': 'storage',
|
||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||
'SESSION_TYPE': 'sqlalchemy',
|
||||
'SESSION_PERMANENT': 'True',
|
||||
'SESSION_USE_SIGNER': 'True',
|
||||
'DEPLOY_ENV': 'PRODUCTION',
|
||||
'SQLALCHEMY_POOL_SIZE': 30,
|
||||
'SQLALCHEMY_POOL_RECYCLE': 3600,
|
||||
@ -61,6 +51,8 @@ DEFAULTS = {
|
||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
|
||||
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
|
||||
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
|
||||
'HOSTED_MODERATION_ENABLED': 'False',
|
||||
'HOSTED_MODERATION_PROVIDERS': '',
|
||||
'TENANT_DOCUMENT_COUNT': 100,
|
||||
'CLEAN_DAY_SETTING': 30,
|
||||
'UPLOAD_FILE_SIZE_LIMIT': 15,
|
||||
@ -100,7 +92,7 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.22"
|
||||
self.CURRENT_VERSION = "0.3.24"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@ -113,20 +105,6 @@ class Config:
|
||||
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
||||
self.SECRET_KEY = get_env('SECRET_KEY')
|
||||
|
||||
# cookie settings
|
||||
self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
|
||||
self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
|
||||
self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
|
||||
self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
|
||||
self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
|
||||
self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
|
||||
self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)
|
||||
|
||||
# session settings, only support sqlalchemy, redis
|
||||
self.SESSION_TYPE = get_env('SESSION_TYPE')
|
||||
self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
|
||||
self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')
|
||||
|
||||
# redis settings
|
||||
self.REDIS_HOST = get_env('REDIS_HOST')
|
||||
self.REDIS_PORT = get_env('REDIS_PORT')
|
||||
@ -135,14 +113,6 @@ class Config:
|
||||
self.REDIS_DB = get_env('REDIS_DB')
|
||||
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
|
||||
|
||||
# session redis settings
|
||||
self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
|
||||
self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
|
||||
self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
|
||||
self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
|
||||
self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
|
||||
self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')
|
||||
|
||||
# storage settings
|
||||
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
|
||||
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
|
||||
@ -230,6 +200,9 @@ class Config:
|
||||
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
|
||||
|
||||
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
|
||||
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
|
||||
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ model_templates = {
|
||||
},
|
||||
'model_config': {
|
||||
'provider': 'openai',
|
||||
'model_id': 'text-davinci-003',
|
||||
'model_id': 'gpt-3.5-turbo-instruct',
|
||||
'configs': {
|
||||
'prompt_template': '',
|
||||
'prompt_variables': [],
|
||||
@ -30,7 +30,7 @@ model_templates = {
|
||||
},
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "text-davinci-003",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@ -104,7 +104,7 @@ demo_model_templates = {
|
||||
'mode': 'completion',
|
||||
'model_config': AppModelConfig(
|
||||
provider='openai',
|
||||
model_id='text-davinci-003',
|
||||
model_id='gpt-3.5-turbo-instruct',
|
||||
configs={
|
||||
'prompt_template': "Please translate the following text into {{target_language}}:\n",
|
||||
'prompt_variables': [
|
||||
@ -140,7 +140,7 @@ demo_model_templates = {
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "text-davinci-003",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@ -222,7 +222,7 @@ demo_model_templates = {
|
||||
'mode': 'completion',
|
||||
'model_config': AppModelConfig(
|
||||
provider='openai',
|
||||
model_id='text-davinci-003',
|
||||
model_id='gpt-3.5-turbo-instruct',
|
||||
configs={
|
||||
'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
|
||||
'prompt_variables': [
|
||||
@ -258,7 +258,7 @@ demo_model_templates = {
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "text-davinci-003",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
|
||||
@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource):
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
|
||||
@ -19,40 +19,13 @@ from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
|
||||
app_detail_fields_with_site
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig, Site
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
|
||||
}
|
||||
|
||||
app_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
def _get_app(app_id, tenant_id):
|
||||
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
|
||||
@ -62,35 +35,6 @@ def _get_app(app_id, tenant_id):
|
||||
|
||||
|
||||
class AppListApi(Resource):
|
||||
prompt_config_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
model_config_partial_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
app_partial_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
app_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -162,7 +106,8 @@ class AppListApi(Resource):
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=model_config_dict
|
||||
config=model_config_dict,
|
||||
mode=args['mode']
|
||||
)
|
||||
|
||||
app = App(
|
||||
@ -236,18 +181,6 @@ class AppListApi(Resource):
|
||||
|
||||
|
||||
class AppTemplateApi(Resource):
|
||||
template_fields = {
|
||||
'name': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'mode': fields.String,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
}
|
||||
|
||||
template_list_fields = {
|
||||
'data': fields.List(fields.Nested(template_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -266,38 +199,6 @@ class AppTemplateApi(Resource):
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
site_fields = {
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean,
|
||||
'app_base_url': fields.String,
|
||||
}
|
||||
|
||||
app_detail_fields_with_site = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'site': fields.Nested(site_fields),
|
||||
'api_base_url': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -13,107 +13,14 @@ from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
|
||||
conversation_message_detail_fields, conversation_with_summary_pagination_fields
|
||||
from libs.helper import TimestampField, datetime_string, uuid_value
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageAnnotation, Conversation
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String,
|
||||
'content': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': fields.Raw,
|
||||
'message_tokens': fields.Integer,
|
||||
'answer': fields.String,
|
||||
'answer_tokens': fields.Integer,
|
||||
'provider_response_latency': fields.Float,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
feedback_stat_fields = {
|
||||
'like': fields.Integer,
|
||||
'dislike': fields.Integer
|
||||
}
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'model': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw,
|
||||
}
|
||||
|
||||
|
||||
class CompletionConversationApi(Resource):
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]['text'] if value else ''
|
||||
|
||||
simple_configs_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
simple_message_detail_fields = {
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': MessageTextField,
|
||||
'answer': fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(),
|
||||
'from_account_id': fields.String,
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -191,21 +98,11 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
|
||||
class CompletionConversationDetailApi(Resource):
|
||||
conversation_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message': fields.Nested(message_detail_fields, attribute='first_message'),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(conversation_detail_fields)
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
def get(self, app_id, conversation_id):
|
||||
app_id = str(app_id)
|
||||
conversation_id = str(conversation_id)
|
||||
@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource):
|
||||
|
||||
|
||||
class ChatConversationApi(Resource):
|
||||
simple_configs_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'summary': fields.String(attribute='summary_or_query'),
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
@marshal_with(conversation_with_summary_pagination_fields)
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
|
||||
@ -356,19 +220,6 @@ class ChatConversationApi(Resource):
|
||||
|
||||
|
||||
class ChatConversationDetailApi(Resource):
|
||||
conversation_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.login.login import login_required
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String,
|
||||
'content': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': fields.Raw,
|
||||
'message_tokens': fields.Integer,
|
||||
'answer': fields.String,
|
||||
'answer_tokens': fields.Integer,
|
||||
'provider_response_latency': fields.Float,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
|
||||
@ -31,7 +31,8 @@ class ModelConfigResource(Resource):
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=request.json
|
||||
config=request.json,
|
||||
mode=app_model.mode
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
|
||||
@ -8,26 +8,11 @@ from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.helper import supported_language
|
||||
from extensions.ext_database import db
|
||||
from models.model import Site
|
||||
|
||||
app_site_fields = {
|
||||
'app_id': fields.String,
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean
|
||||
}
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -45,15 +45,34 @@ class OAuthDataSource(Resource):
|
||||
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
|
||||
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
return { 'data': '' }
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return redirect(auth_url)
|
||||
return { 'data': auth_url }, 200
|
||||
|
||||
|
||||
|
||||
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if 'code' in request.args:
|
||||
code = request.args.get('code')
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
|
||||
|
||||
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@ -69,12 +88,7 @@ class OAuthDataSourceCallback(Resource):
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@ -101,4 +115,5 @@ class OAuthDataSourceSync(Resource):
|
||||
|
||||
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
|
||||
|
||||
@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.setup import setup_required
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
@ -37,12 +36,12 @@ class LoginApi(Resource):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
flask_login.login_user(account, remember=args['remember_me'])
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
# todo: return the user info
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
return {'result': 'success'}
|
||||
return {'result': 'success', 'data': token}
|
||||
|
||||
|
||||
class LogoutApi(Resource):
|
||||
|
||||
@ -2,9 +2,8 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask import request, redirect, current_app
|
||||
from flask_restful import Resource
|
||||
|
||||
from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
|
||||
@ -75,12 +74,11 @@ class OAuthCallback(Resource):
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# login user
|
||||
session.clear()
|
||||
flask_login.login_user(account, remember=True)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
||||
@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
|
||||
from libs.helper import TimestampField
|
||||
from models.dataset import Document
|
||||
from models.source import DataSourceBinding
|
||||
@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
|
||||
class DataSourceApi(Resource):
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields)),
|
||||
'total': fields.Integer
|
||||
}
|
||||
integrate_fields = {
|
||||
'id': fields.String,
|
||||
'provider': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'is_bound': fields.Boolean,
|
||||
'disabled': fields.Boolean,
|
||||
'link': fields.String,
|
||||
'source_info': fields.Nested(integrate_workspace_fields)
|
||||
}
|
||||
integrate_list_fields = {
|
||||
'data': fields.List(fields.Nested(integrate_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -131,28 +101,6 @@ class DataSourceApi(Resource):
|
||||
|
||||
|
||||
class DataSourceNotionListApi(Resource):
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'is_bound': fields.Boolean,
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields))
|
||||
}
|
||||
integrate_notion_info_list_fields = {
|
||||
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
import flask_restful
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console.apikey import api_key_list, api_key_fields
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
@ -12,45 +15,16 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from libs.helper import TimestampField
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Document
|
||||
from models.model import UploadFile
|
||||
from models.model import UploadFile, ApiToken
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'provider': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'app_count': fields.Integer,
|
||||
'document_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'updated_by': fields.String,
|
||||
'updated_at': TimestampField,
|
||||
'embedding_model': fields.String,
|
||||
'embedding_model_provider': fields.String,
|
||||
'embedding_available': fields.Boolean
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"source": fields.String,
|
||||
"source_app_id": fields.String,
|
||||
"created_by_role": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField
|
||||
}
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
@ -82,7 +56,8 @@ class DatasetListApi(Resource):
|
||||
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
# if len(valid_model_list) == 0:
|
||||
# raise ProviderNotInitializeError(
|
||||
# f"No Embedding Model available. Please configure a valid provider "
|
||||
@ -157,7 +132,8 @@ class DatasetApi(Resource):
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
# get valid model list
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
|
||||
|
||||
class DatasetRelatedAppListApi(Resource):
|
||||
app_detail_kernel_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
}
|
||||
|
||||
related_app_list = {
|
||||
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
|
||||
'total': fields.Integer,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource):
|
||||
|
||||
|
||||
class DatasetIndexingStatusApi(Resource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -400,16 +347,101 @@ class DatasetIndexingStatusApi(Resource):
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
documents_status.append(marshal(document, self.document_status_fields))
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = 'dataset-'
|
||||
resource_type = 'dataset'
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list)
|
||||
def get(self):
|
||||
keys = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||
all()
|
||||
return {"items": keys}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_fields)
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||
count()
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restful.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code='max_keys_exceeded'
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
api_token = ApiToken()
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 200
|
||||
|
||||
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = 'dataset'
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id):
|
||||
api_key_id = str(api_key_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
key = db.session.query(ApiToken). \
|
||||
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id). \
|
||||
first()
|
||||
|
||||
if key is None:
|
||||
flask_restful.abort(404, message='API key not found')
|
||||
|
||||
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {
|
||||
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
|
||||
else request.host_url.rstrip('/')) + '/v1'
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
||||
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
|
||||
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
|
||||
@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
|
||||
LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.document_fields import document_with_segments_fields, document_fields, \
|
||||
dataset_and_document_fields, document_status_fields
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetProcessRule, Dataset
|
||||
@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService
|
||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||
|
||||
dataset_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
|
||||
document_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'doc_form': fields.String,
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer
|
||||
}
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
|
||||
class DatasetInitApi(Resource):
|
||||
dataset_and_document_fields = {
|
||||
'dataset': fields.Nested(dataset_fields),
|
||||
'documents': fields.List(fields.Nested(document_fields)),
|
||||
'batch': fields.String
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
|
||||
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
documents_status.append(marshal(document, self.document_status_fields))
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
|
||||
|
||||
class DocumentIndexingStatusApi(DocumentResource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
return marshal(document, self.document_status_fields)
|
||||
return marshal(document, document_status_fields)
|
||||
|
||||
|
||||
class DocumentDetailApi(DocumentResource):
|
||||
|
||||
@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal
|
||||
from flask_restful import Resource, reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
import services
|
||||
@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory
|
||||
from core.login.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
from libs.helper import TimestampField
|
||||
@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
import pandas as pd
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
'index_node_id': fields.String,
|
||||
'index_node_hash': fields.String,
|
||||
'hit_count': fields.Integer,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'status': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'indexing_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField
|
||||
}
|
||||
|
||||
segment_list_response = {
|
||||
'data': fields.List(fields.Nested(segment_fields)),
|
||||
'has_more': fields.Boolean,
|
||||
'limit': fields.Integer
|
||||
}
|
||||
|
||||
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
|
||||
@ -1,28 +1,19 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import tempfile
|
||||
import chardet
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
|
||||
import services
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
|
||||
UnsupportedFileTypeError
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from extensions.ext_storage import storage
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
from fields.file_fields import upload_config_fields, file_fields
|
||||
|
||||
from services.file_service import FileService
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class FileApi(Resource):
|
||||
upload_config_fields = {
|
||||
'file_size_limit': fields.Integer,
|
||||
'batch_count_limit': fields.Integer
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -48,16 +35,6 @@ class FileApi(Resource):
|
||||
'batch_count_limit': batch_count_limit
|
||||
}, 200
|
||||
|
||||
file_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'size': fields.Integer,
|
||||
'extension': fields.String,
|
||||
'mime_type': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -73,45 +50,13 @@ class FileApi(Resource):
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file_content = file.read()
|
||||
file_size = len(file_content)
|
||||
|
||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
|
||||
if file_size > file_size_limit:
|
||||
message = "({file_size} > {file_size_limit})"
|
||||
raise FileTooLargeError(message)
|
||||
|
||||
extension = file.filename.split('.')[-1]
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
try:
|
||||
upload_file = FileService.upload_file(file)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, file_content)
|
||||
|
||||
# save file to db
|
||||
config = current_app.config
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
storage_type=config['STORAGE_TYPE'],
|
||||
key=file_key,
|
||||
name=file.filename,
|
||||
size=file_size,
|
||||
extension=extension,
|
||||
mime_type=file.mimetype,
|
||||
created_by=current_user.id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_content).hexdigest()
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
@ -121,26 +66,7 @@ class FilePreviewApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
key = file_id + request.path
|
||||
cached_response = cache.get(key)
|
||||
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
|
||||
return cached_response['response']
|
||||
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
# extract text from file
|
||||
extension = upload_file.extension
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
text = FileExtractor.load(upload_file, return_text=True)
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||
text = FileService.get_file_preview(file_id)
|
||||
return {'content': text}
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, fields
|
||||
from flask_restful import Resource, reqparse, marshal
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
import services
|
||||
@ -14,48 +14,10 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
from libs.helper import TimestampField
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
document_fields = {
|
||||
'id': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'name': fields.String,
|
||||
'doc_type': fields.String,
|
||||
}
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
'index_node_id': fields.String,
|
||||
'index_node_hash': fields.String,
|
||||
'hit_count': fields.Integer,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'status': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'indexing_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'document': fields.Nested(document_fields),
|
||||
}
|
||||
|
||||
hit_testing_record_fields = {
|
||||
'segment': fields.Nested(segment_fields),
|
||||
'score': fields.Float,
|
||||
'tsne_position': fields.Raw
|
||||
}
|
||||
|
||||
|
||||
class HitTestingApi(Resource):
|
||||
|
||||
|
||||
@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
|
||||
@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):
|
||||
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != 'chat':
|
||||
|
||||
@ -11,32 +11,11 @@ from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
|
||||
app_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String
|
||||
}
|
||||
|
||||
installed_app_fields = {
|
||||
'id': fields.String,
|
||||
'app': fields.Nested(app_fields),
|
||||
'app_owner_tenant_id': fields.String,
|
||||
'is_pinned': fields.Boolean,
|
||||
'last_used_at': TimestampField,
|
||||
'editable': fields.Boolean,
|
||||
'uninstallable': fields.Boolean,
|
||||
}
|
||||
|
||||
installed_app_list_fields = {
|
||||
'installed_apps': fields.List(fields.Nested(installed_app_fields))
|
||||
}
|
||||
|
||||
|
||||
class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@ -26,45 +27,6 @@ from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(InstalledAppResource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
import flask_login
|
||||
from flask import request, current_app
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
@ -58,9 +57,6 @@ class SetupApi(Resource):
|
||||
)
|
||||
|
||||
setup()
|
||||
|
||||
# Login
|
||||
flask_login.login_user(account)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
@ -33,7 +33,6 @@ class UniversalChatApi(UniversalChatResource):
|
||||
args = parser.parse_args()
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
app_model_config
|
||||
|
||||
# update app model config
|
||||
args['model_config'] = app_model_config.to_dict()
|
||||
|
||||
@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
|
||||
conversation_with_model_config_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Raw,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class UniversalChatConversationListApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
@marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
|
||||
|
||||
class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(conversation_with_model_config_fields)
|
||||
def post(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
|
||||
'enabled': v.enabled,
|
||||
'min': v.min,
|
||||
'max': v.max,
|
||||
'default': v.default
|
||||
'default': v.default,
|
||||
'precision': v.precision
|
||||
}
|
||||
for k, v in vars(parameter_rules).items()
|
||||
}
|
||||
@ -285,6 +286,25 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_qualification_verify(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
token=args['token']
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
|
||||
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
|
||||
@ -300,3 +320,5 @@ api.add_resource(ModelProviderPaymentCheckoutUrlApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
|
||||
api.add_resource(ModelProviderFreeQuotaSubmitApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
|
||||
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
|
||||
|
||||
@ -9,4 +9,4 @@ api = ExternalApi(bp)
|
||||
|
||||
from .app import completion, app, conversation, message, audio
|
||||
|
||||
from .dataset import document
|
||||
from .dataset import document, segment, dataset
|
||||
|
||||
@ -8,25 +8,11 @@ from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
import services
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class ConversationApi(AppApiResource):
|
||||
|
||||
@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
class ConversationDetailApi(AppApiResource):
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
|
||||
|
||||
class ConversationRenameApi(AppApiResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
84
api/controllers/service_api/dataset/dataset.py
Normal file
84
api/controllers/service_api/dataset/dataset.py
Normal file
@ -0,0 +1,84 @@
|
||||
from flask import request
|
||||
from flask_restful import reqparse, marshal
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.login.login import current_user
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError('Name must be between 1 to 40 characters.')
|
||||
return name
|
||||
|
||||
|
||||
class DatasetApi(DatasetApiResource):
|
||||
"""Resource for get datasets."""
|
||||
|
||||
def get(self, tenant_id):
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
tenant_id, current_user)
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
else:
|
||||
item['embedding_available'] = True
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
"""Resource for datasets."""
|
||||
|
||||
def post(self, tenant_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='type is required. Name must be between 1 to 40 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=('high_quality', 'economy'),
|
||||
help='Invalid indexing technique.')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args['name'],
|
||||
indexing_technique=args['indexing_technique'],
|
||||
account=current_user
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
|
||||
|
||||
api.add_resource(DatasetApi, '/datasets')
|
||||
|
||||
@ -1,114 +1,291 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import reqparse
|
||||
from flask import current_app, request
|
||||
from flask_restful import reqparse, marshal
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
|
||||
DatasetNotInitedError
|
||||
NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.login.login import current_user
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
|
||||
def post(self, dataset):
|
||||
"""Create document."""
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('doc_type', type=str, location='json')
|
||||
parser.add_argument('doc_metadata', type=dict, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset.indexing_technique:
|
||||
raise DatasetNotInitedError("Dataset indexing technique must be set.")
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
|
||||
doc_type = args.get('doc_type')
|
||||
doc_metadata = args.get('doc_metadata')
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
raise ValueError('indexing_technique is required.')
|
||||
|
||||
if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
|
||||
raise ValueError('Invalid doc_type.')
|
||||
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt'
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, args.get('text'))
|
||||
|
||||
# save file to db
|
||||
config = current_app.config
|
||||
upload_file = UploadFile(
|
||||
tenant_id=dataset.tenant_id,
|
||||
storage_type=config['STORAGE_TYPE'],
|
||||
key=file_key,
|
||||
name=args.get('name') + '.txt',
|
||||
size=len(args.get('text')),
|
||||
extension='txt',
|
||||
mime_type='text/plain',
|
||||
created_by=dataset.created_by,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
used=True,
|
||||
used_by=dataset.created_by,
|
||||
used_at=datetime.datetime.utcnow()
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
document_data = {
|
||||
'data_source': {
|
||||
'type': 'upload_file',
|
||||
'info': [
|
||||
{
|
||||
'upload_file_id': upload_file.id
|
||||
}
|
||||
]
|
||||
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'data_source_type': 'upload_file',
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=document_data,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule,
|
||||
document_data=args,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
if doc_type and doc_metadata:
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
document.doc_metadata = {}
|
||||
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
|
||||
document.doc_type = doc_type
|
||||
document.updated_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return {'id': document.id}
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentApi(DatasetApiResource):
|
||||
def delete(self, dataset, document_id):
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
|
||||
if args['text']:
|
||||
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'data_source_type': 'upload_file',
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
args['original_document_id'] = str(document_id)
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=args,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentAddByFileApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
args = {}
|
||||
if 'data' in request.form:
|
||||
args = json.loads(request.form['data'])
|
||||
if 'doc_form' not in args:
|
||||
args['doc_form'] = 'text_model'
|
||||
if 'doc_language' not in args:
|
||||
args['doc_language'] = 'English'
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
raise ValueError('indexing_technique is required.')
|
||||
|
||||
# save file info
|
||||
file = request.files['file']
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file)
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=args,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
args = {}
|
||||
if 'data' in request.form:
|
||||
args = json.loads(request.form['data'])
|
||||
if 'doc_form' not in args:
|
||||
args['doc_form'] = 'text_model'
|
||||
if 'doc_language' not in args:
|
||||
args['doc_language'] = 'English'
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
if 'file' in request.files:
|
||||
# save file info
|
||||
file = request.files['file']
|
||||
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file)
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
args['original_document_id'] = str(document_id)
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=args,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentDeleteApi(DatasetApiResource):
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError('Cannot delete document during indexing.')
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(DocumentListApi, '/documents')
|
||||
api.add_resource(DocumentApi, '/documents/<uuid:document_id>')
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
query = Document.query.filter_by(
|
||||
dataset_id=str(dataset_id), tenant_id=tenant_id)
|
||||
|
||||
if search:
|
||||
search = f'%{search}%'
|
||||
query = query.filter(Document.name.like(search))
|
||||
|
||||
query = query.order_by(desc(Document.created_at))
|
||||
|
||||
paginated_documents = query.paginate(
|
||||
page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
|
||||
response = {
|
||||
'data': marshal(documents, document_fields),
|
||||
'has_more': len(documents) == limit,
|
||||
'limit': limit,
|
||||
'total': paginated_documents.total,
|
||||
'page': page
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
tenant_id = str(tenant_id)
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# get documents
|
||||
documents = DocumentService.get_batch_documents(dataset_id, batch)
|
||||
if not documents:
|
||||
raise NotFound('Documents not found.')
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
|
||||
api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
|
||||
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
|
||||
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
|
||||
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
|
||||
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
|
||||
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
|
||||
|
||||
@ -1,20 +1,73 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = 'file_too_large'
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class HighQualityDatasetOnlyError(BaseHTTPException):
|
||||
error_code = 'high_quality_dataset_only'
|
||||
description = "Current operation only supports 'high-quality' datasets."
|
||||
code = 400
|
||||
|
||||
|
||||
class DatasetNotInitializedError(BaseHTTPException):
|
||||
error_code = 'dataset_not_initialized'
|
||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
code = 400
|
||||
|
||||
|
||||
class ArchivedDocumentImmutableError(BaseHTTPException):
|
||||
error_code = 'archived_document_immutable'
|
||||
description = "Cannot operate when document was archived."
|
||||
description = "The archived document is not editable."
|
||||
code = 403
|
||||
|
||||
|
||||
class DatasetNameDuplicateError(BaseHTTPException):
|
||||
error_code = 'dataset_name_duplicate'
|
||||
description = "The dataset name already exists. Please modify your dataset name."
|
||||
code = 409
|
||||
|
||||
|
||||
class InvalidActionError(BaseHTTPException):
|
||||
error_code = 'invalid_action'
|
||||
description = "Invalid action."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentAlreadyFinishedError(BaseHTTPException):
|
||||
error_code = 'document_already_finished'
|
||||
description = "The document has been processed. Please refresh the page or go to the document details."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentIndexingError(BaseHTTPException):
|
||||
error_code = 'document_indexing'
|
||||
description = "Cannot operate document during indexing."
|
||||
code = 403
|
||||
description = "The document is being processed and cannot be edited."
|
||||
code = 400
|
||||
|
||||
|
||||
class DatasetNotInitedError(BaseHTTPException):
|
||||
error_code = 'dataset_not_inited'
|
||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
code = 403
|
||||
class InvalidMetadataError(BaseHTTPException):
|
||||
error_code = 'invalid_metadata'
|
||||
description = "The metadata content is incorrect. Please check and verify."
|
||||
code = 400
|
||||
|
||||
59
api/controllers/service_api/dataset/segment.py
Normal file
59
api/controllers/service_api/dataset/segment.py
Normal file
@ -0,0 +1,59 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DocumentService, SegmentService
|
||||
|
||||
|
||||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
for args_item in args['segments']:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
@ -2,11 +2,14 @@
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask import request, current_app
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from core.login.login import _get_user
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant, TenantAccountJoin, Account
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
@ -43,12 +46,24 @@ def validate_dataset_token(view=None):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('dataset')
|
||||
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound()
|
||||
|
||||
return view(dataset, *args, **kwargs)
|
||||
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
||||
.filter(Tenant.id == api_token.tenant_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
.one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account)
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user())
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account is not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant is not exist.")
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
|
||||
@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.web import api
|
||||
from controllers.web.error import NotChatAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class ConversationListApi(WebApiResource):
|
||||
|
||||
@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
|
||||
|
||||
class ConversationRenameApi(WebApiResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from core.helper import moderation
|
||||
from core.model_providers.error import LLMError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
@ -116,6 +118,18 @@ class AgentExecutor:
|
||||
return self.agent.should_use_agent(query)
|
||||
|
||||
def run(self, query: str) -> AgentExecuteResult:
|
||||
moderation_result = moderation.check_moderation(
|
||||
self.configuration.model_instance.model_provider,
|
||||
query
|
||||
)
|
||||
|
||||
if not moderation_result:
|
||||
return AgentExecuteResult(
|
||||
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||
strategy=self.configuration.strategy,
|
||||
configuration=self.configuration
|
||||
)
|
||||
|
||||
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
||||
agent=self.agent,
|
||||
tools=self.configuration.tools,
|
||||
@ -128,7 +142,9 @@ class AgentExecutor:
|
||||
|
||||
try:
|
||||
output = agent_executor.run(query)
|
||||
except Exception:
|
||||
except LLMError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logging.exception("agent_executor run failed")
|
||||
output = None
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
||||
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.model_instant = model_instant
|
||||
self.model_instance = model_instance
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt="\n".join([message.content for message in messages[0]]),
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
|
||||
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
completion_generation = response.generations[0][0]
|
||||
@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
|
||||
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
|
||||
@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instant, self._current_loop
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
)
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instant, self._current_loop
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
|
||||
@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion: str = ''
|
||||
completion_tokens: int = 0
|
||||
latency: float = 0.0
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
self.start_at = time.perf_counter()
|
||||
real_prompts = []
|
||||
for message in messages[0]:
|
||||
if message.type == 'human':
|
||||
@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
self.llm_message.prompt = [{
|
||||
"role": 'user',
|
||||
"text": prompts[0]
|
||||
@ -63,14 +59,22 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
end_at = time.perf_counter()
|
||||
self.llm_message.latency = end_at - self.start_at
|
||||
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
|
||||
if response.llm_output and 'token_usage' in response.llm_output:
|
||||
if 'prompt_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
|
||||
if 'completion_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
@ -89,8 +93,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
"""Do nothing."""
|
||||
if isinstance(error, ConversationTaskStoppedException):
|
||||
if self.conversation_message_task.streaming:
|
||||
end_at = time.perf_counter()
|
||||
self.llm_message.latency = end_at - self.start_at
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
|
||||
@ -1,15 +1,33 @@
|
||||
import enum
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.moderation import openai_moderation
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceRule(BaseModel):
|
||||
class Type(enum.Enum):
|
||||
MODERATION = "moderation"
|
||||
KEYWORDS = "keywords"
|
||||
|
||||
type: Type
|
||||
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
|
||||
extra_params: dict = {}
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceChain(Chain):
|
||||
input_key: str = "input" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
sensitive_words: List[str] = []
|
||||
canned_response: str = None
|
||||
model_instance: BaseLLM
|
||||
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@ -31,11 +49,24 @@ class SensitiveWordAvoidanceChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _check_sensitive_word(self, text: str) -> str:
|
||||
for word in self.sensitive_words:
|
||||
def _check_sensitive_word(self, text: str) -> bool:
|
||||
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
|
||||
if word in text:
|
||||
return self.canned_response
|
||||
return text
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_moderation(self, text: str) -> bool:
|
||||
moderation_model_instance = ModelFactory.get_moderation_model(
|
||||
tenant_id=self.model_instance.model_provider.provider.tenant_id,
|
||||
model_provider_name='openai',
|
||||
model_name=openai_moderation.DEFAULT_MODEL
|
||||
)
|
||||
|
||||
try:
|
||||
return moderation_model_instance.run(text=text)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@ -43,5 +74,19 @@ class SensitiveWordAvoidanceChain(Chain):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
text = inputs[self.input_key]
|
||||
output = self._check_sensitive_word(text)
|
||||
return {self.output_key: output}
|
||||
|
||||
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
|
||||
result = self._check_sensitive_word(text)
|
||||
else:
|
||||
result = self._check_moderation(text)
|
||||
|
||||
if not result:
|
||||
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
|
||||
|
||||
return {self.output_key: text}
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceError(Exception):
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
@ -1,24 +1,22 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Union, Tuple
|
||||
from typing import Optional, List, Union
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
from models.dataset import DocumentSegment, Dataset, Document
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
|
||||
@ -79,28 +77,55 @@ class Completion:
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
# parse sensitive_word_avoidance_chain
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback
|
||||
)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query)
|
||||
# run the final llm
|
||||
try:
|
||||
# parse sensitive_word_avoidance_chain
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
|
||||
final_model_instance, [chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
try:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
except SensitiveWordAvoidanceError as ex:
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
agent_execute_result=None,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
fake_response=ex.message
|
||||
)
|
||||
return
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if query_for_agent and agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query_for_agent)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query_for_agent)
|
||||
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
fake_response = None
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
|
||||
PlanningStrategy.REACT_ROUTER]:
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
# run the final llm
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
@ -109,7 +134,8 @@ class Completion:
|
||||
inputs=inputs,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
fake_response=fake_response
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
return
|
||||
@ -118,20 +144,21 @@ class Completion:
|
||||
logging.warning(f'ChunkedEncodingError: {e}')
|
||||
conversation_message_task.end()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
|
||||
if app.mode != 'completion':
|
||||
return query
|
||||
|
||||
return inputs.get(app_model_config.dataset_query_variable, "")
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
|
||||
inputs: dict,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
fake_response = None
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
|
||||
fake_response: Optional[str]):
|
||||
# get llm prompt
|
||||
prompt_messages, stop_words = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import decimal
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, Union, List
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
@ -23,6 +23,8 @@ class ConversationMessageTask:
|
||||
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
|
||||
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
|
||||
conversation: Optional[Conversation] = None, is_override: bool = False):
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
self.task_id = task_id
|
||||
|
||||
self.app = app
|
||||
@ -61,6 +63,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
def init(self):
|
||||
|
||||
override_model_configs = None
|
||||
if self.is_override:
|
||||
override_model_configs = self.app_model_config.to_dict()
|
||||
@ -109,7 +112,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(self.conversation)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
self.message = Message(
|
||||
app_id=self.app_model_config.app_id,
|
||||
@ -137,7 +140,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(self.message)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
def append_message_text(self, text: str):
|
||||
if text is not None:
|
||||
@ -165,7 +168,7 @@ class ConversationMessageTask:
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
self.message.answer_price_unit = answer_price_unit
|
||||
self.message.provider_response_latency = llm_message.latency
|
||||
self.message.provider_response_latency = time.perf_counter() - self.start_at
|
||||
self.message.total_price = total_price
|
||||
|
||||
db.session.commit()
|
||||
@ -188,12 +191,13 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
|
||||
message_chain.output = json.dumps(chain_result.completion)
|
||||
db.session.commit()
|
||||
|
||||
self._pub_handler.pub_chain(message_chain)
|
||||
|
||||
@ -214,24 +218,24 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(message_agent_thought)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
self._pub_handler.pub_agent_thought(message_agent_thought)
|
||||
|
||||
return message_agent_thought
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
|
||||
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
@ -245,8 +249,8 @@ class ConversationMessageTask:
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
message_agent_thought.currency = agent_model_instant.get_currency()
|
||||
db.session.flush()
|
||||
message_agent_thought.currency = agent_model_instance.get_currency()
|
||||
db.session.commit()
|
||||
|
||||
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
|
||||
dataset_query = DatasetQuery(
|
||||
@ -259,6 +263,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def on_dataset_query_finish(self, resource: List):
|
||||
if resource and len(resource) > 0:
|
||||
@ -282,7 +287,7 @@ class ConversationMessageTask:
|
||||
created_by=self.user.id
|
||||
)
|
||||
db.session.add(dataset_retriever_resource)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
self.retriever_resource = resource
|
||||
|
||||
def message_end(self):
|
||||
|
||||
@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
|
||||
SEARCH_URL = "https://api.notion.com/v1/search"
|
||||
|
||||
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
||||
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
||||
|
||||
34
api/core/helper/moderation.py
Normal file
34
api/core/helper/moderation.py
Normal file
@ -0,0 +1,34 @@
|
||||
import logging
|
||||
|
||||
import openai
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
||||
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
|
||||
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and model_provider.provider_name in hosted_config.moderation.providers:
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
|
||||
max_text_chunks = 32
|
||||
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
|
||||
|
||||
for text_chunk in chunks:
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -16,6 +16,10 @@ class BaseIndex(ABC):
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
raise NotImplementedError
|
||||
@ -28,6 +32,10 @@ class BaseIndex(ABC):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = {}
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self.dataset.id,
|
||||
keyword_table=json.dumps({
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": {}
|
||||
}
|
||||
}, cls=SetEncoder)
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
return self
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
|
||||
@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
def _save_dataset_keyword_table(self, keyword_table):
|
||||
keyword_table_dict = {
|
||||
'__type__': 'keyword_table',
|
||||
@ -214,11 +246,28 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def multi_create_segment_keywords(self, pre_segment_data_list: list):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for pre_segment_data in pre_segment_data_list:
|
||||
segment = pre_segment_data['segment']
|
||||
if pre_segment_data['keywords']:
|
||||
segment.keywords = pre_segment_data['keywords']
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
|
||||
pre_segment_data['keywords'])
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(segment.content,
|
||||
self._config.max_keywords_per_chunk)
|
||||
segment.keywords = list(keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
|
||||
class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
index: KeywordTableIndex
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
@ -110,6 +110,14 @@ class BaseVectorIndex(BaseIndex):
|
||||
for node_id in ids:
|
||||
vector_store.del_text(node_id)
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
if self.dataset.collection_binding_id:
|
||||
vector_store.delete_by_group_id(group_id)
|
||||
else:
|
||||
vector_store.delete()
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
@ -243,3 +251,53 @@ class BaseVectorIndex(BaseIndex):
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.add_texts(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"delete original collection: {dataset.id}")
|
||||
|
||||
self.delete()
|
||||
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=collection_name,
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
|
||||
@ -28,6 +28,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
@ -84,6 +85,7 @@ class Qdrant(VectorStore):
|
||||
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR_NAME = None
|
||||
|
||||
def __init__(
|
||||
@ -93,9 +95,12 @@ class Qdrant(VectorStore):
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
group_payload_key: str = GROUP_KEY,
|
||||
group_id: str = None,
|
||||
distance_strategy: str = "COSINE",
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
embedding_function: Optional[Callable] = None, # deprecated
|
||||
is_new_collection: bool = False
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
try:
|
||||
@ -129,7 +134,10 @@ class Qdrant(VectorStore):
|
||||
self.collection_name = collection_name
|
||||
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
||||
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
||||
self.group_payload_key = group_payload_key or self.GROUP_KEY
|
||||
self.vector_name = vector_name or self.VECTOR_NAME
|
||||
self.group_id = group_id
|
||||
self.is_new_collection= is_new_collection
|
||||
|
||||
if embedding_function is not None:
|
||||
warnings.warn(
|
||||
@ -170,6 +178,8 @@ class Qdrant(VectorStore):
|
||||
batch_size:
|
||||
How many vectors upload per-request.
|
||||
Default: 64
|
||||
group_id:
|
||||
collection group
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
@ -182,7 +192,11 @@ class Qdrant(VectorStore):
|
||||
collection_name=self.collection_name, points=points, **kwargs
|
||||
)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
# if is new collection, create payload index on group_id
|
||||
if self.is_new_collection:
|
||||
self.client.create_payload_index(self.collection_name, self.group_payload_key,
|
||||
field_schema=PayloadSchemaType.KEYWORD,
|
||||
field_type=PayloadSchemaType.KEYWORD)
|
||||
return added_ids
|
||||
|
||||
@sync_call_fallback
|
||||
@ -970,6 +984,8 @@ class Qdrant(VectorStore):
|
||||
distance_func: str = "Cosine",
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
group_payload_key: str = GROUP_KEY,
|
||||
group_id: str = None,
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
batch_size: int = 64,
|
||||
shard_number: Optional[int] = None,
|
||||
@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key:
|
||||
A payload key used to store the metadata of the document.
|
||||
Default: "metadata"
|
||||
group_payload_key:
|
||||
A payload key used to store the content of the document.
|
||||
Default: "group_id"
|
||||
group_id:
|
||||
collection group id
|
||||
vector_name:
|
||||
Name of the vector to be used internally in Qdrant.
|
||||
Default: None
|
||||
@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
|
||||
distance_func,
|
||||
content_payload_key,
|
||||
metadata_payload_key,
|
||||
group_payload_key,
|
||||
group_id,
|
||||
vector_name,
|
||||
shard_number,
|
||||
replication_factor,
|
||||
@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
|
||||
distance_func: str = "Cosine",
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
group_payload_key: str = GROUP_KEY,
|
||||
group_id: str = None,
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
shard_number: Optional[int] = None,
|
||||
replication_factor: Optional[int] = None,
|
||||
@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
|
||||
vector_size = len(partial_embeddings[0])
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
distance_func = distance_func.upper()
|
||||
is_new_collection = False
|
||||
client = qdrant_client.QdrantClient(
|
||||
location=location,
|
||||
url=url,
|
||||
@ -1364,70 +1390,12 @@ class Qdrant(VectorStore):
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# Skip any validation in case of forced collection recreate.
|
||||
if force_recreate:
|
||||
raise ValueError
|
||||
|
||||
# Get the vector configuration of the existing collection and vector, if it
|
||||
# was specified. If the old configuration does not match the current one,
|
||||
# an exception is being thrown.
|
||||
collection_info = client.get_collection(collection_name=collection_name)
|
||||
current_vector_config = collection_info.config.params.vectors
|
||||
if isinstance(current_vector_config, dict) and vector_name is not None:
|
||||
if vector_name not in current_vector_config:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} does not "
|
||||
f"contain vector named {vector_name}. Did you mean one of the "
|
||||
f"existing vectors: {', '.join(current_vector_config.keys())}? "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
current_vector_config = current_vector_config.get(
|
||||
vector_name
|
||||
) # type: ignore[assignment]
|
||||
elif isinstance(current_vector_config, dict) and vector_name is None:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} uses named vectors. "
|
||||
f"If you want to reuse it, please set `vector_name` to any of the "
|
||||
f"existing named vectors: "
|
||||
f"{', '.join(current_vector_config.keys())}." # noqa
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
elif (
|
||||
not isinstance(current_vector_config, dict) and vector_name is not None
|
||||
):
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} doesn't use named "
|
||||
f"vectors. If you want to reuse it, please set `vector_name` to "
|
||||
f"`None`. If you want to recreate the collection, set "
|
||||
f"`force_recreate` parameter to `True`."
|
||||
)
|
||||
|
||||
# Check if the vector configuration has the same dimensionality.
|
||||
if current_vector_config.size != vector_size: # type: ignore[union-attr]
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for vectors with "
|
||||
f"{current_vector_config.size} " # type: ignore[union-attr]
|
||||
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
|
||||
current_distance_func = (
|
||||
current_vector_config.distance.name.upper() # type: ignore[union-attr]
|
||||
)
|
||||
if current_distance_func != distance_func:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for "
|
||||
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||
f"similarity. Please set `distance_func` parameter to "
|
||||
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||
f"recreate the collection, set `force_recreate` parameter to "
|
||||
f"`True`."
|
||||
)
|
||||
except (UnexpectedResponse, RpcError, ValueError):
|
||||
all_collection_name = []
|
||||
collections_response = client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[distance_func],
|
||||
@ -1454,6 +1422,68 @@ class Qdrant(VectorStore):
|
||||
init_from=init_from,
|
||||
timeout=timeout, # type: ignore[arg-type]
|
||||
)
|
||||
is_new_collection = True
|
||||
if force_recreate:
|
||||
raise ValueError
|
||||
|
||||
# Get the vector configuration of the existing collection and vector, if it
|
||||
# was specified. If the old configuration does not match the current one,
|
||||
# an exception is being thrown.
|
||||
collection_info = client.get_collection(collection_name=collection_name)
|
||||
current_vector_config = collection_info.config.params.vectors
|
||||
if isinstance(current_vector_config, dict) and vector_name is not None:
|
||||
if vector_name not in current_vector_config:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} does not "
|
||||
f"contain vector named {vector_name}. Did you mean one of the "
|
||||
f"existing vectors: {', '.join(current_vector_config.keys())}? "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
current_vector_config = current_vector_config.get(
|
||||
vector_name
|
||||
) # type: ignore[assignment]
|
||||
elif isinstance(current_vector_config, dict) and vector_name is None:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} uses named vectors. "
|
||||
f"If you want to reuse it, please set `vector_name` to any of the "
|
||||
f"existing named vectors: "
|
||||
f"{', '.join(current_vector_config.keys())}." # noqa
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
elif (
|
||||
not isinstance(current_vector_config, dict) and vector_name is not None
|
||||
):
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} doesn't use named "
|
||||
f"vectors. If you want to reuse it, please set `vector_name` to "
|
||||
f"`None`. If you want to recreate the collection, set "
|
||||
f"`force_recreate` parameter to `True`."
|
||||
)
|
||||
|
||||
# Check if the vector configuration has the same dimensionality.
|
||||
if current_vector_config.size != vector_size: # type: ignore[union-attr]
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for vectors with "
|
||||
f"{current_vector_config.size} " # type: ignore[union-attr]
|
||||
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
|
||||
current_distance_func = (
|
||||
current_vector_config.distance.name.upper() # type: ignore[union-attr]
|
||||
)
|
||||
if current_distance_func != distance_func:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for "
|
||||
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||
f"similarity. Please set `distance_func` parameter to "
|
||||
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||
f"recreate the collection, set `force_recreate` parameter to "
|
||||
f"`True`."
|
||||
)
|
||||
qdrant = cls(
|
||||
client=client,
|
||||
collection_name=collection_name,
|
||||
@ -1462,6 +1492,9 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
distance_strategy=distance_func,
|
||||
vector_name=vector_name,
|
||||
group_id=group_id,
|
||||
group_payload_key=group_payload_key,
|
||||
is_new_collection=is_new_collection
|
||||
)
|
||||
return qdrant
|
||||
|
||||
@ -1516,6 +1549,8 @@ class Qdrant(VectorStore):
|
||||
metadatas: Optional[List[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str
|
||||
) -> List[dict]:
|
||||
payloads = []
|
||||
for i, text in enumerate(texts):
|
||||
@ -1529,6 +1564,7 @@ class Qdrant(VectorStore):
|
||||
{
|
||||
content_payload_key: text,
|
||||
metadata_payload_key: metadata,
|
||||
group_payload_key: group_id
|
||||
}
|
||||
)
|
||||
|
||||
@ -1578,7 +1614,7 @@ class Qdrant(VectorStore):
|
||||
else:
|
||||
out.append(
|
||||
rest.FieldCondition(
|
||||
key=f"{self.metadata_payload_key}.{key}",
|
||||
key=key,
|
||||
match=rest.MatchValue(value=value),
|
||||
)
|
||||
)
|
||||
@ -1654,6 +1690,7 @@ class Qdrant(VectorStore):
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@ -1684,6 +1721,8 @@ class Qdrant(VectorStore):
|
||||
batch_metadatas,
|
||||
self.content_payload_key,
|
||||
self.metadata_payload_key,
|
||||
self.group_id,
|
||||
self.group_payload_key
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document, BaseRetriever
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http.models import HnswConfigDiff
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.qdrant_vector_store import QdrantVectorStore
|
||||
from models.dataset import Dataset
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
root_path: Optional[str]
|
||||
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
return 'qdrant'
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
return dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
return class_prefix
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
ids=uuids,
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id',
|
||||
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False),
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = QdrantVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
collection_name=collection_name,
|
||||
ids=uuids,
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id',
|
||||
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False),
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
client = qdrant_client.QdrantClient(
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
client=client,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embeddings=self._embeddings,
|
||||
content_payload_key='page_content'
|
||||
content_payload_key='page_content',
|
||||
group_id=self.dataset.id,
|
||||
group_payload_key='group_id'
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return QdrantVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
@ -132,6 +154,35 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=group_id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
|
||||
@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
import flask_login
|
||||
from flask import current_app
|
||||
from flask import g
|
||||
from flask import has_request_context
|
||||
from flask import request
|
||||
from flask import request, session
|
||||
from flask_login import user_logged_in
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.moderation.base import BaseModeration
|
||||
from core.model_providers.models.speech2text.base import BaseSpeech2Text
|
||||
from extensions.ext_database import db
|
||||
from models.provider import TenantDefaultModel
|
||||
@ -180,7 +181,7 @@ class ModelFactory:
|
||||
def get_moderation_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: str,
|
||||
model_name: str) -> Optional[BaseProviderModel]:
|
||||
model_name: str) -> Optional[BaseModeration]:
|
||||
"""
|
||||
get moderation model.
|
||||
|
||||
|
||||
@ -45,6 +45,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'wenxin':
|
||||
from core.model_providers.providers.wenxin_provider import WenxinProvider
|
||||
return WenxinProvider
|
||||
elif provider_name == 'zhipuai':
|
||||
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
|
||||
return ZhipuAIProvider
|
||||
elif provider_name == 'chatglm':
|
||||
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
|
||||
return ChatGLMProvider
|
||||
|
||||
@ -0,0 +1,22 @@
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class HuggingfaceEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = HuggingfaceHubEmbeddings(
|
||||
model=name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Huggingface embedding: {str(ex)}")
|
||||
@ -0,0 +1,22 @@
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
|
||||
|
||||
|
||||
class ZhipuAIEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = ZhipuAIEmbeddings(
|
||||
model=name,
|
||||
**credentials,
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")
|
||||
@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
|
||||
max: Optional[T] = None
|
||||
default: Optional[T] = None
|
||||
alias: Optional[str] = None
|
||||
precision: Optional[int] = None
|
||||
|
||||
|
||||
class ModelKwargsRules(BaseModel):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union, Tuple
|
||||
import decimal
|
||||
@ -10,6 +11,7 @@ from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
|
||||
from core.helper import moderation
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
@ -19,6 +21,8 @@ from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
import logging
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -116,9 +120,20 @@ class BaseLLM(BaseProviderModel):
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
moderation_result = moderation.check_moderation(
|
||||
self.model_provider,
|
||||
"\n".join([message.content for message in messages])
|
||||
)
|
||||
|
||||
if not moderation_result:
|
||||
kwargs['fake_response'] = "I apologize for any confusion, " \
|
||||
"but I'm an AI assistant to be helpful, harmless, and honest."
|
||||
|
||||
if self.deduct_quota:
|
||||
self.model_provider.check_quota_over_limit()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if not callbacks:
|
||||
callbacks = self.callbacks
|
||||
else:
|
||||
|
||||
@ -17,6 +17,7 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
COMPLETION_MODELS = [
|
||||
'gpt-3.5-turbo-instruct', # 4,096 tokens
|
||||
'text-davinci-003', # 4,097 tokens
|
||||
]
|
||||
|
||||
@ -31,6 +32,7 @@ MODEL_MAX_TOKENS = {
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-instruct': 8192,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ class WenxinModel(BaseLLM):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
# TODO load price_config from configs(db)
|
||||
return Wenxin(
|
||||
model=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
|
||||
61
api/core/model_providers/models/llm/zhipuai_model.py
Normal file
61
api/core/model_providers/models/llm/zhipuai_model.py
Normal file
@ -0,0 +1,61 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
|
||||
|
||||
|
||||
class ZhipuAIModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ZhipuAIChatLLM(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"ZhipuAI: {str(ex)}")
|
||||
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
29
api/core/model_providers/models/moderation/base.py
Normal file
29
api/core/model_providers/models/moderation/base.py
Normal file
@ -0,0 +1,29 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class BaseModeration(BaseProviderModel):
|
||||
name: str
|
||||
type: ModelType = ModelType.MODERATION
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
def run(self, text: str) -> bool:
|
||||
try:
|
||||
return self._run(text)
|
||||
except Exception as ex:
|
||||
raise self.handle_exceptions(ex)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, text: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
raise NotImplementedError
|
||||
@ -4,29 +4,39 @@ import openai
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.models.moderation.base import BaseModeration
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
DEFAULT_AUDIO_MODEL = 'whisper-1'
|
||||
DEFAULT_MODEL = 'whisper-1'
|
||||
|
||||
|
||||
class OpenAIModeration(BaseProviderModel):
|
||||
type: ModelType = ModelType.MODERATION
|
||||
class OpenAIModeration(BaseModeration):
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
super().__init__(model_provider, openai.Moderation)
|
||||
super().__init__(model_provider, openai.Moderation, name)
|
||||
|
||||
def run(self, text):
|
||||
def _run(self, text: str) -> bool:
|
||||
credentials = self.model_provider.get_model_credentials(
|
||||
model_name=DEFAULT_AUDIO_MODEL,
|
||||
model_name=self.name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
try:
|
||||
return self._client.create(input=text, api_key=credentials['openai_api_key'])
|
||||
except Exception as ex:
|
||||
raise self.handle_exceptions(ex)
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
|
||||
max_text_chunks = 32
|
||||
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
|
||||
|
||||
for text_chunk in chunks:
|
||||
moderation_result = self._client.create(input=text_chunk,
|
||||
api_key=credentials['openai_api_key'])
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
|
||||
@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
|
||||
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
model_credentials = self.get_model_credentials(model_name, model_type)
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
|
||||
model_credentials['base_model_name'],
|
||||
4097
|
||||
), default=16),
|
||||
), default=16, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
|
||||
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
|
||||
hosted_model_providers = HostedModelProviders()
|
||||
|
||||
|
||||
class HostedModerationConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
providers: list[str] = []
|
||||
|
||||
|
||||
class HostedConfig(BaseModel):
|
||||
moderation = HostedModerationConfig()
|
||||
|
||||
|
||||
hosted_config = HostedConfig()
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
langchain.verbose = True
|
||||
@ -78,3 +90,9 @@ def init_app(app: Flask):
|
||||
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
|
||||
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
|
||||
)
|
||||
|
||||
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
|
||||
hosted_config.moderation = HostedModerationConfig(
|
||||
enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
|
||||
providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Type
|
||||
import requests
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
@ -10,8 +11,12 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
|
||||
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
|
||||
from models.provider import ProviderType
|
||||
|
||||
HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
|
||||
|
||||
|
||||
class HuggingfaceHubProvider(BaseModelProvider):
|
||||
@property
|
||||
@ -33,6 +38,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = HuggingfaceHubModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = HuggingfaceEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -47,11 +54,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -63,7 +70,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if model_type != ModelType.TEXT_GENERATION:
|
||||
if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
|
||||
raise NotImplementedError
|
||||
|
||||
if 'huggingfacehub_api_type' not in credentials \
|
||||
@ -88,19 +95,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'task_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Task Type must be provided.')
|
||||
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", 'feature-extraction'):
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
|
||||
'text-generation, summarization.')
|
||||
'text-generation, feature-extraction.')
|
||||
|
||||
try:
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
if credentials['task_type'] == 'feature-extraction':
|
||||
cls.check_embedding_valid(credentials, model_name)
|
||||
else:
|
||||
cls.check_llm_valid(credentials)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
||||
else:
|
||||
@ -112,13 +115,64 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
||||
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "feature-extraction")
|
||||
if model_info.pipeline_tag not in VALID_TASKS:
|
||||
raise ValueError(f"Model {model_name} is not a valid task, "
|
||||
f"must be one of {VALID_TASKS}.")
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
||||
|
||||
@classmethod
|
||||
def check_llm_valid(cls, credentials: dict):
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
|
||||
@classmethod
|
||||
def check_embedding_valid(cls, credentials: dict, model_name: str):
|
||||
|
||||
cls.check_endpoint_url_model_repository_name(credentials, model_name)
|
||||
|
||||
embedding_model = HuggingfaceHubEmbeddings(
|
||||
model=model_name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
embedding_model.embed_query("ping")
|
||||
|
||||
@classmethod
|
||||
def check_endpoint_url_model_repository_name(cls, credentials: dict, model_name: str):
|
||||
try:
|
||||
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
|
||||
headers = {
|
||||
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
response =requests.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError('User Name or Organization Name is invalid.')
|
||||
|
||||
model_repository_name = ''
|
||||
|
||||
for item in response.json().get("items", []):
|
||||
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
|
||||
model_repository_name = item.get("model", {}).get("repository")
|
||||
break
|
||||
|
||||
if model_repository_name != model_name:
|
||||
raise ValueError(f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
|
||||
@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=0.7),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1),
|
||||
max_tokens=KwargRule[int](min=10, max=4097, default=16),
|
||||
temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.95),
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -40,6 +40,10 @@ class OpenAIProvider(BaseModelProvider):
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-instruct',
|
||||
'name': 'GPT-3.5-Turbo-Instruct',
|
||||
},
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-16k',
|
||||
'name': 'gpt-3.5-turbo-16k',
|
||||
@ -128,16 +132,17 @@ class OpenAIProvider(BaseModelProvider):
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-instruct': 8192,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
|
||||
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
|
||||
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
|
||||
default=float(value.get('default')) if value.get('default') is not None else None,
|
||||
precision = 2
|
||||
)
|
||||
if key == 'temperature':
|
||||
model_kwargs_rules.temperature = kwarg_rule
|
||||
@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
|
||||
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
|
||||
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
|
||||
default=int(value.get('default')) if value.get('default') is not None else 500,
|
||||
precision = 0
|
||||
)
|
||||
|
||||
return model_kwargs_rules
|
||||
|
||||
@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=0.5),
|
||||
temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
|
||||
top_p=KwargRule[float](enabled=False),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
|
||||
max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](enabled=False),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.8),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -61,13 +61,18 @@ class WenxinProvider(BaseModelProvider):
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
model_max_tokens = {
|
||||
'ernie-bot': 4800,
|
||||
'ernie-bot-turbo': 11200,
|
||||
}
|
||||
|
||||
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
|
||||
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
|
||||
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False, max=model_max_tokens.get(model_name)),
|
||||
)
|
||||
else:
|
||||
return ModelKwargsRules(
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
@ -52,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
|
||||
credentials = self.get_model_credentials(model_name, model_type)
|
||||
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
|
||||
)
|
||||
elif credentials['model_format'] == "ggmlv3":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
|
||||
)
|
||||
else:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
|
||||
)
|
||||
|
||||
|
||||
@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
|
||||
'model_uid': credentials['model_uid'],
|
||||
}
|
||||
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
llm("ping")
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
embedding = XinferenceEmbeddings(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
|
||||
|
||||
176
api/core/model_providers/providers/zhipuai_provider.py
Normal file
176
api/core/model_providers/providers/zhipuai_provider.py
Normal file
@ -0,0 +1,176 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
class ZhipuAIProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'zhipuai'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'chatglm_pro',
|
||||
'name': 'chatglm_pro',
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_std',
|
||||
'name': 'chatglm_std',
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_lite',
|
||||
'name': 'chatglm_lite',
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_lite_32k',
|
||||
'name': 'chatglm_lite_32k',
|
||||
}
|
||||
]
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
return [
|
||||
{
|
||||
'id': 'text_embedding',
|
||||
'name': 'text_embedding',
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = ZhipuAIModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = ZhipuAIEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
|
||||
top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'api_key': credentials['api_key']
|
||||
}
|
||||
|
||||
llm = ZhipuAIChatLLM(
|
||||
temperature=0.01,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm([HumanMessage(content='ping')])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value \
|
||||
or (self.provider.provider_type == ProviderType.SYSTEM.value
|
||||
and self.provider.quota_type == ProviderQuotaType.FREE.value):
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_key': None,
|
||||
}
|
||||
|
||||
if credentials['api_key']:
|
||||
credentials['api_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||
|
||||
return credentials
|
||||
else:
|
||||
return {}
|
||||
|
||||
def should_deduct_quota(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.get_provider_credentials(obfuscated)
|
||||
@ -6,6 +6,7 @@
|
||||
"tongyi",
|
||||
"spark",
|
||||
"wenxin",
|
||||
"zhipuai",
|
||||
"chatglm",
|
||||
"replicate",
|
||||
"huggingface_hub",
|
||||
|
||||
@ -30,6 +30,12 @@
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-instruct": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
|
||||
44
api/core/model_providers/rules/zhipuai.json
Normal file
44
api/core/model_providers/rules/zhipuai.json
Normal file
@ -0,0 +1,44 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"system",
|
||||
"custom"
|
||||
],
|
||||
"system_config": {
|
||||
"supported_quota_types": [
|
||||
"free"
|
||||
],
|
||||
"quota_unit": "tokens"
|
||||
},
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"chatglm_pro": {
|
||||
"prompt": "0.01",
|
||||
"completion": "0.01",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"chatglm_std": {
|
||||
"prompt": "0.005",
|
||||
"completion": "0.005",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"chatglm_lite": {
|
||||
"prompt": "0.002",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"chatglm_lite_32k": {
|
||||
"prompt": "0.0004",
|
||||
"completion": "0.0004",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"text_embedding": {
|
||||
"completion": "0",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
from langchain import WikipediaAPIWrapper
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.model import AppModelConfig
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class OrchestratorRuleParser:
|
||||
@ -63,7 +65,7 @@ class OrchestratorRuleParser:
|
||||
|
||||
# add agent callback to record agent thoughts
|
||||
agent_callback = AgentLoopGatherCallbackHandler(
|
||||
model_instant=agent_model_instance,
|
||||
model_instance=agent_model_instance,
|
||||
conversation_message_task=conversation_message_task
|
||||
)
|
||||
|
||||
@ -123,23 +125,45 @@ class OrchestratorRuleParser:
|
||||
|
||||
return chain
|
||||
|
||||
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
|
||||
def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
|
||||
-> Optional[SensitiveWordAvoidanceChain]:
|
||||
"""
|
||||
Convert app sensitive word avoidance config to chain
|
||||
|
||||
:param model_instance: model instance
|
||||
:param callbacks: callbacks for the chain
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if not self.app_model_config.sensitive_word_avoidance_dict:
|
||||
return None
|
||||
sensitive_word_avoidance_rule = None
|
||||
|
||||
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
|
||||
sensitive_words = sensitive_word_avoidance_config.get("words", "")
|
||||
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
|
||||
if self.app_model_config.sensitive_word_avoidance_dict:
|
||||
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
|
||||
if sensitive_word_avoidance_config.get("enabled", False):
|
||||
if sensitive_word_avoidance_config.get('type') == 'moderation':
|
||||
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
|
||||
type=SensitiveWordAvoidanceRule.Type.MODERATION,
|
||||
canned_response=sensitive_word_avoidance_config.get("canned_response")
|
||||
if sensitive_word_avoidance_config.get("canned_response")
|
||||
else 'Your content violates our usage policy. Please revise and try again.',
|
||||
)
|
||||
else:
|
||||
sensitive_words = sensitive_word_avoidance_config.get("words", "")
|
||||
if sensitive_words:
|
||||
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
|
||||
type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
|
||||
canned_response=sensitive_word_avoidance_config.get("canned_response")
|
||||
if sensitive_word_avoidance_config.get("canned_response")
|
||||
else 'Your content violates our usage policy. Please revise and try again.',
|
||||
extra_params={
|
||||
'sensitive_words': sensitive_words.split(','),
|
||||
}
|
||||
)
|
||||
|
||||
if sensitive_word_avoidance_rule:
|
||||
return SensitiveWordAvoidanceChain(
|
||||
sensitive_words=sensitive_words.split(","),
|
||||
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
|
||||
model_instance=model_instance,
|
||||
sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
|
||||
output_key="sensitive_word_avoidance_output",
|
||||
callbacks=callbacks,
|
||||
**kwargs
|
||||
|
||||
74
api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
vendored
Normal file
74
api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
vendored
Normal file
@ -0,0 +1,74 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
HOSTED_INFERENCE_API = 'hosted_inference_api'
|
||||
INFERENCE_ENDPOINTS = 'inference_endpoints'
|
||||
|
||||
|
||||
class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
|
||||
client: Any
|
||||
model: str
|
||||
|
||||
huggingface_namespace: Optional[str] = None
|
||||
task_type: Optional[str] = None
|
||||
huggingfacehub_api_type: Optional[str] = None
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
huggingfacehub_endpoint_url: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values['huggingfacehub_api_token'] = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
|
||||
values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])
|
||||
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
model = ''
|
||||
|
||||
if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
|
||||
model = self.model
|
||||
else:
|
||||
model = self.huggingfacehub_endpoint_url
|
||||
|
||||
output = self.client.post(
|
||||
json={
|
||||
"inputs": texts,
|
||||
"options": {
|
||||
"wait_for_model": False,
|
||||
"use_cache": False
|
||||
}
|
||||
}, model=model)
|
||||
|
||||
embeddings = json.loads(output.decode())
|
||||
return self.mean_pooling(embeddings)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
|
||||
# Returned values are a list of floats, or a list of list of floats
|
||||
# (depending on if you sent a string or a list of string,
|
||||
# and if the automatic reduction, usually mean_pooling for instance was applied for you or not.
|
||||
# This should be explained on the model's README.)
|
||||
def mean_pooling(self, embeddings: List) -> List[float]:
|
||||
# If automatic reduction by giving model, no need to mean_pooling.
|
||||
# For example one: List[List[float]]
|
||||
if not isinstance(embeddings[0][0], list):
|
||||
return embeddings
|
||||
|
||||
# For example two: List[List[List[float]]], need to mean_pooling.
|
||||
sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings]
|
||||
return sentence_embeddings
|
||||
64
api/core/third_party/langchain/embeddings/zhipuai_embedding.py
vendored
Normal file
64
api/core/third_party/langchain/embeddings/zhipuai_embedding.py
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
"""Wrapper around ZhipuAI embedding models."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
|
||||
|
||||
|
||||
class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around ZhipuAI embedding models.
|
||||
1024 dimensions.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model: str
|
||||
"""Model name to use."""
|
||||
|
||||
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
|
||||
api_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "ZHIPUAI_API_KEY"
|
||||
)
|
||||
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
response = self.client.invoke(model=self.model, prompt=text)
|
||||
data = response["data"]
|
||||
embeddings.append(data.get('embedding'))
|
||||
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
@ -16,7 +16,7 @@ class HuggingFaceHubLLM(HuggingFaceHub):
|
||||
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Only supports `text-generation`, `text2text-generation` and `summarization` for now.
|
||||
Only supports `text-generation`, `text2text-generation` for now.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@ -14,6 +14,9 @@ class EnhanceOpenAI(OpenAI):
|
||||
max_retries: int = 1
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
def __new__(cls, **data: Any): # type: ignore
|
||||
return super(EnhanceOpenAI, cls).__new__(cls)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
||||
315
api/core/third_party/langchain/llms/zhipuai_llm.py
vendored
Normal file
315
api/core/third_party/langchain/llms/zhipuai_llm.py
vendored
Normal file
@ -0,0 +1,315 @@
|
||||
"""Wrapper around ZhipuAI APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import posixpath
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional, Iterator, Sequence,
|
||||
)
|
||||
|
||||
import zhipuai
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
|
||||
from pydantic import Extra, root_validator, BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils import jwt_token
|
||||
from zhipuai.utils.http_client import post, stream
|
||||
from zhipuai.utils.sse_client import SSEClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZhipuModelAPI(BaseModel):
|
||||
base_url: str
|
||||
api_key: str
|
||||
api_timeout_seconds = 60
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SYNC)
|
||||
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
if not response['success']:
|
||||
raise ValueError(
|
||||
f"Error Code: {response['code']}, Message: {response['msg']} "
|
||||
)
|
||||
return response
|
||||
|
||||
def sse_invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SSE)
|
||||
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
return SSEClient(data)
|
||||
|
||||
def _build_api_url(self, kwargs, *path):
|
||||
if kwargs:
|
||||
if "model" not in kwargs:
|
||||
raise Exception("model param missed")
|
||||
model = kwargs.pop("model")
|
||||
else:
|
||||
model = "-"
|
||||
|
||||
return posixpath.join(self.base_url, model, *path)
|
||||
|
||||
def _generate_token(self):
|
||||
if not self.api_key:
|
||||
raise Exception(
|
||||
"api_key not provided, you could provide it."
|
||||
)
|
||||
|
||||
try:
|
||||
return jwt_token.generate_token(self.api_key)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Your api_key is invalid, please check it."
|
||||
)
|
||||
|
||||
|
||||
class ZhipuAIChatLLM(BaseChatModel):
|
||||
"""Wrapper around ZhipuAI large language models.
|
||||
To use, you should pass the api_key as a named parameter to the constructor.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from core.third_party.langchain.llms.zhipuai import ZhipuAI
|
||||
model = ZhipuAI(model="<model_name>", api_key="my-api-key")
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"api_key": "API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model: str = "chatglm_lite"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.95
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
top_p: float = 0.7
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response or return it all at once."""
|
||||
api_key: Optional[str] = None
|
||||
|
||||
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "ZHIPUAI_API_KEY"
|
||||
)
|
||||
|
||||
if 'test' in values['base_url']:
|
||||
values['model'] = 'chatglm_130b_test'
|
||||
|
||||
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return self._default_params
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "zhipuai"
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict["content"])
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
dict_messages = []
|
||||
for m in messages:
|
||||
message = self._convert_message_to_dict(m)
|
||||
if dict_messages:
|
||||
previous_message = dict_messages[-1]
|
||||
if previous_message['role'] == message['role']:
|
||||
dict_messages[-1]['content'] += f"\n{message['content']}"
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
|
||||
return dict_messages
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
llm_output: Optional[Dict] = None
|
||||
for chunk in self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if chunk.generation_info is not None \
|
||||
and 'token_usage' in chunk.generation_info:
|
||||
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
|
||||
continue
|
||||
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
else:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["prompt"] = message_dicts
|
||||
request.update(kwargs)
|
||||
response = self.client.invoke(**request)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["prompt"] = message_dicts
|
||||
request.update(kwargs)
|
||||
|
||||
for event in self.client.sse_invoke(incremental=True, **request).events():
|
||||
if event.event == "add":
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(event.data)
|
||||
elif event.event == "error" or event.event == "interrupted":
|
||||
raise ValueError(
|
||||
f"{event.data}"
|
||||
)
|
||||
elif event.event == "finish":
|
||||
meta = json.loads(event.meta)
|
||||
token_usage = meta['usage']
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
|
||||
yield ChatGenerationChunk(
|
||||
message=AIMessageChunk(content=event.data),
|
||||
generation_info=dict({'token_usage': token_usage})
|
||||
)
|
||||
|
||||
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
|
||||
data = response["data"]
|
||||
generations = []
|
||||
for res in data["choices"]:
|
||||
message = self._convert_dict_to_message(res)
|
||||
gen = ChatGeneration(
|
||||
message=message
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = data.get("usage")
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
# def get_token_ids(self, text: str) -> List[int]:
|
||||
# """Return the ordered ids of the tokens in a text.
|
||||
#
|
||||
# Args:
|
||||
# text: The string input to tokenize.
|
||||
#
|
||||
# Returns:
|
||||
# A list of ids corresponding to the tokens in the text, in order they occur
|
||||
# in the text.
|
||||
# """
|
||||
# from core.third_party.transformers.Token import ChatGLMTokenizer
|
||||
#
|
||||
# tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
|
||||
# return tokenizer.encode(text)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
return {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
|
||||
return_resource: str
|
||||
retriever_from: str
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
description = dataset.description
|
||||
@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': self.k
|
||||
'k': self.k,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
|
||||
@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
|
||||
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
def delete_group(self):
|
||||
self._reload_if_needed()
|
||||
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
|
||||
@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
|
||||
@dataset_was_deleted.connect
|
||||
def handle(sender, **kwargs):
|
||||
dataset = sender
|
||||
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
|
||||
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
|
||||
dataset.index_struct, dataset.collection_binding_id)
|
||||
|
||||
@ -1,174 +0,0 @@
|
||||
import redis
|
||||
from redis.connection import SSLConnection, Connection
|
||||
from flask import request
|
||||
from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface
|
||||
from flask_session.sessions import total_seconds
|
||||
from itsdangerous import want_bytes
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
sess = Session()
|
||||
|
||||
|
||||
def init_app(app):
|
||||
sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
|
||||
app,
|
||||
db,
|
||||
app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
|
||||
app.config.get('SESSION_KEY_PREFIX', 'session:'),
|
||||
app.config.get('SESSION_USE_SIGNER', False),
|
||||
app.config.get('SESSION_PERMANENT', True)
|
||||
)
|
||||
|
||||
session_type = app.config.get('SESSION_TYPE')
|
||||
if session_type == 'sqlalchemy':
|
||||
app.session_interface = sqlalchemy_session_interface
|
||||
elif session_type == 'redis':
|
||||
connection_class = Connection
|
||||
if app.config.get('SESSION_REDIS_USE_SSL', False):
|
||||
connection_class = SSLConnection
|
||||
|
||||
sess_redis_client = redis.Redis()
|
||||
sess_redis_client.connection_pool = redis.ConnectionPool(**{
|
||||
'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
|
||||
'port': app.config.get('SESSION_REDIS_PORT', 6379),
|
||||
'username': app.config.get('SESSION_REDIS_USERNAME', None),
|
||||
'password': app.config.get('SESSION_REDIS_PASSWORD', None),
|
||||
'db': app.config.get('SESSION_REDIS_DB', 2),
|
||||
'encoding': 'utf-8',
|
||||
'encoding_errors': 'strict',
|
||||
'decode_responses': False
|
||||
}, connection_class=connection_class)
|
||||
|
||||
app.extensions['session_redis'] = sess_redis_client
|
||||
|
||||
app.session_interface = CustomRedisSessionInterface(
|
||||
sess_redis_client,
|
||||
app.config.get('SESSION_KEY_PREFIX', 'session:'),
|
||||
app.config.get('SESSION_USE_SIGNER', False),
|
||||
app.config.get('SESSION_PERMANENT', True)
|
||||
)
|
||||
|
||||
|
||||
class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
db,
|
||||
table,
|
||||
key_prefix,
|
||||
use_signer=False,
|
||||
permanent=True,
|
||||
sequence=None,
|
||||
autodelete=False,
|
||||
):
|
||||
if db is None:
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
|
||||
db = SQLAlchemy(app)
|
||||
self.db = db
|
||||
self.key_prefix = key_prefix
|
||||
self.use_signer = use_signer
|
||||
self.permanent = permanent
|
||||
self.autodelete = autodelete
|
||||
self.sequence = sequence
|
||||
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
|
||||
|
||||
class Session(self.db.Model):
|
||||
__tablename__ = table
|
||||
|
||||
if sequence:
|
||||
id = self.db.Column( # noqa: A003, VNE003, A001
|
||||
self.db.Integer, self.db.Sequence(sequence), primary_key=True
|
||||
)
|
||||
else:
|
||||
id = self.db.Column( # noqa: A003, VNE003, A001
|
||||
self.db.Integer, primary_key=True
|
||||
)
|
||||
|
||||
session_id = self.db.Column(self.db.String(255), unique=True)
|
||||
data = self.db.Column(self.db.LargeBinary)
|
||||
expiry = self.db.Column(self.db.DateTime)
|
||||
|
||||
def __init__(self, session_id, data, expiry):
|
||||
self.session_id = session_id
|
||||
self.data = data
|
||||
self.expiry = expiry
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Session data {self.data}>"
|
||||
|
||||
self.sql_session_model = Session
|
||||
|
||||
def save_session(self, *args, **kwargs):
|
||||
if request.blueprint == 'service_api':
|
||||
return
|
||||
elif request.method == 'OPTIONS':
|
||||
return
|
||||
elif request.endpoint and request.endpoint == 'health':
|
||||
return
|
||||
return super().save_session(*args, **kwargs)
|
||||
|
||||
|
||||
class CustomRedisSessionInterface(RedisSessionInterface):
|
||||
|
||||
def save_session(self, app, session, response):
|
||||
if request.blueprint == 'service_api':
|
||||
return
|
||||
elif request.method == 'OPTIONS':
|
||||
return
|
||||
elif request.endpoint and request.endpoint == 'health':
|
||||
return
|
||||
|
||||
if not self.should_set_cookie(app, session):
|
||||
return
|
||||
domain = self.get_cookie_domain(app)
|
||||
path = self.get_cookie_path(app)
|
||||
if not session:
|
||||
if session.modified:
|
||||
self.redis.delete(self.key_prefix + session.sid)
|
||||
response.delete_cookie(
|
||||
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
|
||||
)
|
||||
return
|
||||
|
||||
# Modification case. There are upsides and downsides to
|
||||
# emitting a set-cookie header each request. The behavior
|
||||
# is controlled by the :meth:`should_set_cookie` method
|
||||
# which performs a quick check to figure out if the cookie
|
||||
# should be set or not. This is controlled by the
|
||||
# SESSION_REFRESH_EACH_REQUEST config flag as well as
|
||||
# the permanent flag on the session itself.
|
||||
# if not self.should_set_cookie(app, session):
|
||||
# return
|
||||
conditional_cookie_kwargs = {}
|
||||
httponly = self.get_cookie_httponly(app)
|
||||
secure = self.get_cookie_secure(app)
|
||||
if self.has_same_site_capability:
|
||||
conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
|
||||
expires = self.get_expiration_time(app, session)
|
||||
|
||||
if session.permanent:
|
||||
value = self.serializer.dumps(dict(session))
|
||||
if value is not None:
|
||||
self.redis.setex(
|
||||
name=self.key_prefix + session.sid,
|
||||
value=value,
|
||||
time=total_seconds(app.permanent_session_lifetime),
|
||||
)
|
||||
|
||||
if self.use_signer:
|
||||
session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
|
||||
else:
|
||||
session_id = session.sid
|
||||
response.set_cookie(
|
||||
app.config["SESSION_COOKIE_NAME"],
|
||||
session_id,
|
||||
expires=expires,
|
||||
httponly=httponly,
|
||||
domain=domain,
|
||||
path=path,
|
||||
secure=secure,
|
||||
**conditional_cookie_kwargs,
|
||||
)
|
||||
0
api/fields/__init__.py
Normal file
0
api/fields/__init__.py
Normal file
138
api/fields/app_fields.py
Normal file
138
api/fields/app_fields.py
Normal file
@ -0,0 +1,138 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
app_detail_kernel_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
}
|
||||
|
||||
related_app_list = {
|
||||
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
|
||||
'total': fields.Integer,
|
||||
}
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||
'dataset_query_variable': fields.String,
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
|
||||
}
|
||||
|
||||
app_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
prompt_config_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
model_config_partial_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
app_partial_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
app_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
|
||||
}
|
||||
|
||||
template_fields = {
|
||||
'name': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'mode': fields.String,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
}
|
||||
|
||||
template_list_fields = {
|
||||
'data': fields.List(fields.Nested(template_fields)),
|
||||
}
|
||||
|
||||
site_fields = {
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean,
|
||||
'app_base_url': fields.String,
|
||||
}
|
||||
|
||||
app_detail_fields_with_site = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'site': fields.Nested(site_fields),
|
||||
'api_base_url': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
app_site_fields = {
|
||||
'app_id': fields.String,
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean
|
||||
}
|
||||
182
api/fields/conversation_fields.py
Normal file
182
api/fields/conversation_fields.py
Normal file
@ -0,0 +1,182 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]['text'] if value else ''
|
||||
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String,
|
||||
'content': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': fields.Raw,
|
||||
'message_tokens': fields.Integer,
|
||||
'answer': fields.String,
|
||||
'answer_tokens': fields.Integer,
|
||||
'provider_response_latency': fields.Float,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
feedback_stat_fields = {
|
||||
'like': fields.Integer,
|
||||
'dislike': fields.Integer
|
||||
}
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'model': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw,
|
||||
}
|
||||
|
||||
simple_configs_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
simple_message_detail_fields = {
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': MessageTextField,
|
||||
'answer': fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(),
|
||||
'from_account_id': fields.String,
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
||||
}
|
||||
|
||||
conversation_message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message': fields.Nested(message_detail_fields, attribute='first_message'),
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
conversation_with_summary_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'summary': fields.String(attribute='summary_or_query'),
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
conversation_with_summary_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
|
||||
}
|
||||
|
||||
conversation_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
simple_conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(simple_conversation_fields))
|
||||
}
|
||||
|
||||
conversation_with_model_config_fields = {
|
||||
**simple_conversation_fields,
|
||||
'model_config': fields.Raw,
|
||||
}
|
||||
|
||||
conversation_with_model_config_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
|
||||
}
|
||||
65
api/fields/data_source_fields.py
Normal file
65
api/fields/data_source_fields.py
Normal file
@ -0,0 +1,65 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'is_bound': fields.Boolean,
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields))
|
||||
}
|
||||
|
||||
integrate_notion_info_list_fields = {
|
||||
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
|
||||
}
|
||||
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields)),
|
||||
'total': fields.Integer
|
||||
}
|
||||
|
||||
integrate_fields = {
|
||||
'id': fields.String,
|
||||
'provider': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'is_bound': fields.Boolean,
|
||||
'disabled': fields.Boolean,
|
||||
'link': fields.String,
|
||||
'source_info': fields.Nested(integrate_workspace_fields)
|
||||
}
|
||||
|
||||
integrate_list_fields = {
|
||||
'data': fields.List(fields.Nested(integrate_fields)),
|
||||
}
|
||||
43
api/fields/dataset_fields.py
Normal file
43
api/fields/dataset_fields.py
Normal file
@ -0,0 +1,43 @@
|
||||
from flask_restful import fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
dataset_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'provider': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'app_count': fields.Integer,
|
||||
'document_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'updated_by': fields.String,
|
||||
'updated_at': TimestampField,
|
||||
'embedding_model': fields.String,
|
||||
'embedding_model_provider': fields.String,
|
||||
'embedding_available': fields.Boolean
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"source": fields.String,
|
||||
"source_app_id": fields.String,
|
||||
"created_by_role": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField
|
||||
}
|
||||
76
api/fields/document_fields.py
Normal file
76
api/fields/document_fields.py
Normal file
@ -0,0 +1,76 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
document_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'doc_form': fields.String,
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer
|
||||
}
|
||||
|
||||
dataset_and_document_fields = {
|
||||
'dataset': fields.Nested(dataset_fields),
|
||||
'documents': fields.List(fields.Nested(document_fields)),
|
||||
'batch': fields.String
|
||||
}
|
||||
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user