Compare commits

...

35 Commits

Author SHA1 Message Date
43741ad5d1 feat: bump version to 0.3.34 (#1788) 2023-12-19 14:08:47 +08:00
8dec406161 chore: enchance ext name (#1787) 2023-12-19 14:03:24 +08:00
58f8d74591 fix: unstructured file extension (#1785) 2023-12-19 12:09:48 +08:00
867fc61b12 fix: web app text (#1784) 2023-12-19 11:45:16 +08:00
8e2e477a7f chore: enchance annotation ui (#1781) 2023-12-19 10:25:54 +08:00
9b34f5a9ff feat: unstructured frontend (#1777) 2023-12-18 23:28:25 +08:00
5e34f938c1 Feat/add unstructured support (#1780)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 23:24:06 +08:00
2fd56cb01c Fix/vdb index issue (#1776)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 21:33:54 +08:00
4f0e272549 fix: add then eidt annotion cause show bug (#1775) 2023-12-18 19:33:48 +08:00
1a5279a3ef fix: get billing info in self-hosted edition from current workspace (#1774) 2023-12-18 17:54:16 +08:00
7775f5785f chore: update annotation reply english i18n (#1773) 2023-12-18 17:13:15 +08:00
2de73991ff feat: only tenant owner can subscription. (#1770) 2023-12-18 16:59:31 +08:00
354d033e60 fix: not owner can not pay (#1772) 2023-12-18 16:54:47 +08:00
ebc2cdad2e fix annotation query exception (#1771)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 16:48:34 +08:00
5bb841935e feat: custom webapp logo (#1766) 2023-12-18 16:25:37 +08:00
65fd4b39ce feat: annotation management frontend (#1764) 2023-12-18 15:41:24 +08:00
96d2de2258 fix annotation reply in universal chat (#1768)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 15:04:17 +08:00
a71f2863ac Annotation management (#1767)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 13:10:05 +08:00
a9b942981d fix: issue templates not render correctly (#1763) 2023-12-18 09:22:11 +08:00
4b1ba2ec21 feat: remove billing config. (#1761) 2023-12-17 13:22:45 +08:00
c09184fd94 update bm25 search properties (#1758)
Co-authored-by: Blade <zhangxiaobin@unixyz.cn>
2023-12-15 12:28:03 +08:00
b0d8d196e1 azure openai add gpt-4-1106-preview、gpt-4-vision-preview models (#1751)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2023-12-14 09:55:30 +08:00
7c43123956 feat: can replace logo. (#1752) 2023-12-13 20:21:39 +08:00
eede84eb9e feat: web app support some feature (#1753) 2023-12-13 20:21:11 +08:00
b5b20234e9 feat: update pricing (#1749) 2023-12-13 16:41:40 +08:00
5beb298e47 chore: update term links (#1748) 2023-12-13 15:12:27 +08:00
6b499b9a16 remove stripe and anthropic. (#1746) 2023-12-12 17:59:07 +08:00
4c639961f5 add self checks to issues and discussions templates (#1742) 2023-12-11 23:25:17 +08:00
dfd3f507fb fix: ad block disabled tracking would block ga then can not pay (#1741) 2023-12-11 16:36:58 +08:00
d5695b3170 check rerank document is not empty (#1740)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-11 16:12:11 +08:00
994fceece3 fix: qa regex (#1738) 2023-12-11 15:53:37 +08:00
8c451eb0e6 fix only full text search in app issue (#1736)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-11 15:34:29 +08:00
79b4366203 fix: server component use translate errorts lint error (#1732) 2023-12-11 10:15:49 +08:00
3675d2eae8 fix: prompt null parse var error (#1731) 2023-12-11 10:06:01 +08:00
38b55d2186 fix: default types (#1728) 2023-12-09 23:38:07 +08:00
261 changed files with 9025 additions and 1156 deletions

View File

@ -1,11 +1,18 @@
name: "🕷️ Bug report"
description: Report errors or unexpected behavior [please use English :]
description: Report errors or unexpected behavior
labels:
- bug
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: input
attributes:
label: Dify version

View File

@ -1,8 +1,16 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation [please use English :]
description: Report issues in our documentation
labels:
- ducumentation
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of requested docs changes

View File

@ -1,8 +1,17 @@
name: "⭐ Feature or enhancement request"
description: Propose something new. [please use English :]
description: Propose something new.
labels:
- enhancement
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Description of the new feature / enhancement

View File

@ -3,6 +3,15 @@ description: "Request help from the community" [please use English :]
labels:
- help-wanted
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of the help you need

View File

@ -3,9 +3,15 @@ description: Report incorrect translations. [please use English :]
labels:
- translation
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: input
attributes:
label: Dify version

View File

@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
@ -119,16 +117,6 @@ HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
# Stripe configuration
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Billing configuration
BILLING_API_URL=http://127.0.0.1:8000/v1
BILLING_API_SECRET_KEY=
STRIPE_WEBHOOK_BILLING_SECRET=
ETL_TYPE=dify
UNSTRUCTURED_API_URL=

View File

@ -20,7 +20,7 @@ from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension
from extensions.ext_database import db
from extensions.ext_login import login_manager
@ -96,7 +96,6 @@ def initialize_extensions(app):
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
# Flask-Login configuration

View File

@ -28,7 +28,7 @@ from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
import secrets
import base64
@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
pbar.update(len(data_batch))
@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
click.echo(click.style('Start add annotation question value.', fg='green'))
message_annotations = db.session.query(MessageAnnotation).all()
message_annotation_deal_count = 0
if message_annotations:
for message_annotation in message_annotations:
try:
if message_annotation.message_id and not message_annotation.question:
message = db.session.query(Message).filter(
Message.id == message_annotation.message_id
).first()
message_annotation.question = message.query
db.session.add(message_annotation)
db.session.commit()
message_annotation_deal_count += 1
except Exception as e:
click.echo(
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@ -766,3 +790,4 @@ def register_commands(app):
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
app.cli.add_command(add_annotation_question_field_value)

View File

@ -1,11 +1,8 @@
# -*- coding:utf-8 -*-
import os
from datetime import timedelta
import dotenv
from extensions.ext_database import db
from extensions.ext_redis import redis_client
dotenv.load_dotenv()
@ -44,15 +41,11 @@ DEFAULTS = {
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'CLEAN_DAY_SETTING': 30,
@ -61,7 +54,8 @@ DEFAULTS = {
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
'INVITE_EXPIRY_HOURS': 72
'INVITE_EXPIRY_HOURS': 72,
'ETL_TYPE': 'dify',
}
@ -91,7 +85,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.33"
self.CURRENT_VERSION = "0.3.34"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@ -268,8 +262,6 @@ class Config:
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@ -281,14 +273,13 @@ class Config:
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.ETL_TYPE = get_env('ETL_TYPE')
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
class CloudEditionConfig(Config):
@ -302,6 +293,3 @@ class CloudEditionConfig(Config):
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import extension, setup, version, apikey, admin
# Import app controllers
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio
# Import webhook controllers
from .webhook import stripe
from .billing import billing

View File

@ -0,0 +1,290 @@
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, marshal
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
annotation_hit_history_fields
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from flask import request
class AnnotationReplyActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
args = parser.parse_args()
if action == 'enable':
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == 'disable':
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError('Unsupported annotation reply action')
return result, 200
class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id, annotation_setting_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
keyword = request.args.get('keyword', default=None, type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
'data': marshal(annotation_list, annotation_fields),
'has_more': len(annotation_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {
'data': marshal(annotation_list, annotation_fields)
}
return response, 200
class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {'result': 'success'}, 200
class AnnotationBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
class AnnotationBatchImportStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, annotation_id):
# The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
page, limit)
response = {
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
'has_more': len(annotation_hit_history_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')

View File

@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400

View File

@ -6,22 +6,23 @@ from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
# get app info
app = _get_app(app_id)
parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('content', type=str, location='json')
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
message_id = str(args['message_id'])
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app.id
).first()
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
if annotation:
annotation.content = args['content']
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args['content'],
account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
return {'result': 'success'}
return annotation
class MessageAnnotationCountApi(Resource):

View File

@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
"""Modify app model config"""
app_id = str(app_id)
app_model = _get_app(app_id)
app = _get_app(app_id)
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app_model.mode
mode=app.mode
)
new_app_model_config = AppModelConfig(
app_id=app_model.id,
app_id=app.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config)
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app.app_model_config_id = new_app_model_config.id
db.session.commit()
app_model_config_was_updated.send(
app_model,
app,
app_model_config=new_app_model_config
)

View File

@ -1,9 +1,6 @@
import stripe
import os
from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app, request
from flask import current_app
from controllers.console import api
from controllers.console.setup import setup_required
@ -40,7 +37,12 @@ class Subscription(Resource):
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()
return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
BillingService.is_tenant_owner(current_user)
return BillingService.get_subscription(args['plan'],
args['interval'],
current_user.email,
current_user.current_tenant_id)
class Invoices(Resource):
@ -50,36 +52,10 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
BillingService.is_tenant_owner(current_user)
return BillingService.get_invoices(current_user.email)
class StripeBillingWebhook(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400
BillingService.process_event(event)
return 'success', 200
api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')

View File

@ -69,5 +69,20 @@ class FilePreviewApi(Resource):
return {'content': text}
class FileeSupportTypApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
etl_type = current_app.config['ETL_TYPE']
if etl_type == 'Unstructured':
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
else:
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
return {'allowed_extensions': allowed_extensions}
api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileeSupportTypApi, '/files/support-type')

View File

@ -73,7 +73,7 @@ class ConversationRenameApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:

View File

@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@ -66,7 +66,7 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:

View File

@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw
}
@marshal_with(parameters_fields)
@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
}

View File

@ -1,61 +0,0 @@
import logging
import stripe
from flask import request, current_app
from flask_restful import Resource
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService
class StripeWebhookApi(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400
# Handle the checkout.session.completed event
if event['type'] == 'checkout.session.completed':
logging.debug(event['data']['object']['id'])
logging.debug(event['data']['object']['amount_subtotal'])
logging.debug(event['data']['object']['currency'])
logging.debug(event['data']['object']['payment_intent'])
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])
session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)
logging.debug(session.line_items['data'][0]['quantity'])
# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()
try:
provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:
logging.debug(str(e))
return 'success', 200
return 'success', 200
api.add_resource(StripeWebhookApi, '/webhook/stripe')

View File

@ -9,8 +9,8 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_checkout_service import ProviderCheckoutService
from services.provider_service import ProviderService
from services.billing_service import BillingService
class ModelProviderListApi(Resource):
@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
provider_service = ProviderCheckoutService()
provider_checkout = provider_service.create_checkout(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
account=current_user
)
if provider_name != 'anthropic':
raise ValueError(f'provider name {provider_name} is invalid')
return {
'url': provider_checkout.get_checkout_url()
}
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id)
return data
class ModelProviderFreeQuotaSubmitApi(Resource):

View File

@ -10,12 +10,15 @@ from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.setup import setup_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, UnsupportedFileTypeError
from libs.helper import TimestampField
from extensions.ext_database import db
from models.account import Tenant
import services
from services.account_service import TenantService
from services.workspace_service import WorkspaceService
from services.file_service import FileService
provider_fields = {
'provider_name': fields.String,
@ -34,6 +37,7 @@ tenant_fields = {
'providers': fields.List(fields.Nested(provider_fields)),
'in_trial': fields.Boolean,
'trial_end_reason': fields.String,
'custom_config': fields.Raw(attribute='custom_config'),
}
tenants_fields = {
@ -130,6 +134,61 @@ class SwitchWorkspaceApi(Resource):
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
class CustomConfigWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('remove_webapp_brand', type=bool, location='json')
parser.add_argument('replace_webapp_logo', type=str, location='json')
args = parser.parse_args()
custom_config_dict = {
'remove_webapp_brand': args['remove_webapp_brand'],
'replace_webapp_logo': args['replace_webapp_logo'],
}
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant.custom_config_dict = custom_config_dict
db.session.commit()
return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
class WebappLogoWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
def post(self):
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
extension = file.filename.split('.')[-1]
if extension.lower() not in ['svg', 'png']:
raise UnsupportedFileTypeError()
try:
upload_file = FileService.upload_file(file, current_user, True)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return { 'id': upload_file.id }, 201
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
@ -137,3 +196,5 @@ api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all ten
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')

View File

@ -55,6 +55,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = billing_info['members']
apps = billing_info['apps']
vector_space = billing_info['vector_space']
annotation_quota_limit = billing_info['annotation_quota_limit']
if resource == 'members' and 0 < members['limit'] <= members['size']:
abort(403, error_msg)
@ -62,6 +63,10 @@ def cloud_edition_billing_resource_check(resource: str,
abort(403, error_msg)
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
abort(403, error_msg)
elif resource == 'workspace_custom' and not billing_info['can_replace_logo']:
abort(403, error_msg)
elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] < annotation_quota_limit['size']:
abort(403, error_msg)
else:
return view(*args, **kwargs)

View File

@ -1,10 +1,12 @@
from flask import request, Response
from flask_restful import Resource
from werkzeug.exceptions import NotFound
import services
from controllers.files import api
from libs.exception import BaseHTTPException
from services.file_service import FileService
from services.account_service import TenantService
class ImagePreviewApi(Resource):
@ -29,9 +31,30 @@ class ImagePreviewApi(Resource):
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
class WorkspaceWebappLogoApi(Resource):
def get(self, workspace_id):
workspace_id = str(workspace_id)
custom_config = TenantService.get_custom_config(workspace_id)
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
if not webapp_logo_file_id:
raise NotFound(f'webapp logo is not found')
try:
generator, mimetype = FileService.get_public_image_preview(
webapp_logo_file_id,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')
class UnsupportedFileTypeError(BaseHTTPException):

View File

@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@ -98,7 +98,7 @@ class ChatApi(AppApiResource):
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
args = parser.parse_args()

View File

@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@ -1,11 +1,15 @@
# -*- coding:utf-8 -*-
import os
from flask_restful import fields, marshal_with
from flask import current_app
from werkzeug.exceptions import Forbidden
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.model import Site
from services.billing_service import BillingService
class AppSiteApi(WebApiResource):
@ -39,6 +43,8 @@ class AppSiteApi(WebApiResource):
'site': fields.Nested(site_fields),
'model_config': fields.Nested(model_config_fields, allow_null=True),
'plan': fields.String,
'can_replace_logo': fields.Boolean,
'custom_config': fields.Raw(attribute='custom_config'),
}
@marshal_with(app_fields)
@ -50,7 +56,14 @@ class AppSiteApi(WebApiResource):
if not site:
raise Forbidden()
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id)
edition = os.environ.get('EDITION')
can_replace_logo = False
if edition == 'CLOUD':
info = BillingService.get_info(app_model.tenant_id)
can_replace_logo = info['can_replace_logo']
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
api.add_resource(AppSiteApi, '/site')
@ -59,7 +72,7 @@ api.add_resource(AppSiteApi, '/site')
class AppSiteInfo:
"""Class to store site information."""
def __init__(self, tenant, app, site, end_user):
def __init__(self, tenant, app, site, end_user, can_replace_logo):
"""Initialize AppSiteInfo instance."""
self.app_id = app.id
self.end_user_id = end_user
@ -67,6 +80,16 @@ class AppSiteInfo:
self.site = site
self.model_config = None
self.plan = tenant.plan
self.can_replace_logo = can_replace_logo
if can_replace_logo:
base_url = current_app.config.get('FILES_URL')
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
self.custom_config = {
'remove_webapp_brand': remove_webapp_brand,
'replace_webapp_logo': replace_webapp_logo,
}
if app.enable_site and site.prompt_public:
app_model_config = app.app_model_config

View File

@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.embedding.cached_embedding import CacheEmbedding
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.dataset import Dataset
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
class Completion:
@ -33,7 +38,7 @@ class Completion:
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
auto_generate_name: bool = True):
auto_generate_name: bool = True, from_source: str = 'console'):
"""
errors: ProviderTokenNotInitError
"""
@ -109,7 +114,10 @@ class Completion:
fake_response=str(e)
)
return
# check annotation reply
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
if annotation_reply:
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
@ -166,17 +174,18 @@ class Completion:
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation = ModerationFactory(type, app_id, tenant_id,
app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
@ -324,6 +333,81 @@ class Completion:
external_context = memory.load_memory_variables({})
return external_context[memory_key]
@classmethod
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
from_source: str) -> bool:
"""Get memory messages."""
app_model_config = conversation_message_task.app_model_config
app = conversation_message_task.app
annotation_reply = app_model_config.annotation_reply_dict
if annotation_reply['enabled']:
try:
score_threshold = annotation_reply.get('score_threshold', 1)
embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
tenant_id=app.tenant_id,
model_provider_name=embedding_provider_name,
model_name=embedding_model_name
)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
'annotation'
)
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
documents = vector_index.search(
conversation_message_task.query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
annotation_id = documents[0].metadata['annotation_id']
score = documents[0].metadata['score']
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
# insert annotation history
AppAnnotationService.add_annotation_history(annotation.id,
app.id,
annotation.question,
annotation.content,
conversation_message_task.query,
conversation_message_task.user.id,
conversation_message_task.message.id,
from_source,
score)
return True
except Exception as e:
logging.warning(f'Query annotation failed, exception: {str(e)}.')
return False
return False
@classmethod
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
conversation: Conversation,

View File

@ -319,6 +319,10 @@ class ConversationMessageTask:
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end()
def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
self._pub_handler.pub_end()
class PubHandler:
def __init__(self, user: Union[Account, EndUser], task_id: str,
@ -435,7 +439,7 @@ class PubHandler:
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
'conversation_id': self._conversation.id,
}
}
if retriever_resource:
@ -446,6 +450,30 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()
def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
content = {
'event': 'annotation',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id,
'text': text,
'annotation_id': annotation_id,
'annotation_author_name': annotation_author_name
}
}
self._message.answer = text
self._message.provider_response_latency = time.perf_counter() - start_at
db.session.commit()
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_end(self):
content = {
'event': 'end',

View File

@ -3,7 +3,8 @@ from pathlib import Path
from typing import List, Union, Optional
import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from flask import current_app
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
@ -11,6 +12,13 @@ from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile
@ -49,14 +57,34 @@ class FileExtractor:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
if is_automatic:
loader = UnstructuredFileLoader(
file_path, strategy="hi_res", mode="elements"
)
# loader = UnstructuredAPIFileLoader(
# file_path=filenames[0],
# api_key="FAKE_API_KEY",
# )
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension == '.docx':
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)

View File

@ -0,0 +1,41 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredEmailLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,48 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredMarkdownLoader(BaseLoader):
"""Load md files.
Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
remove_images: Whether to remove images from the text.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.md import partition_md
elements = partition_md(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredMsgLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.msg import partition_msg
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredPPTLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.ppt import partition_ppt
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredPPTXLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredTextLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.text import partition_text
elements = partition_text(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredXmlLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.xml import partition_xml
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -10,7 +10,7 @@ from flask import current_app
from extensions.ext_storage import storage
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
class UploadFileParser:

View File

@ -32,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError

View File

@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):
self._save_dataset_keyword_table(keyword_table)
def delete_by_metadata_field(self, key: str, value: str):
pass
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)

View File

@ -100,7 +100,6 @@ class MilvusVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
@ -121,6 +120,16 @@ class MilvusVectorIndex(BaseVectorIndex):
'filter': f'id in {ids}'
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_metadata_field(key, value)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()

View File

@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
))
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()

View File

@ -9,12 +9,17 @@ from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings)
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
self._attributes = attributes
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
@ -33,7 +38,8 @@ class VectorIndex:
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings
embeddings=embeddings,
attributes=attributes
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

View File

@ -27,9 +27,10 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
@ -111,7 +112,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
attributes = self._attributes
if self._is_origin():
attributes = ['doc_id']
@ -141,6 +142,27 @@ class WeaviateVectorIndex(BaseVectorIndex):
"valueText": document_id
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": [key],
"valueText": value
})
def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

View File

@ -397,7 +397,7 @@ class IndexingRunner:
one_or_none()
if file_detail:
text_docs = FileExtractor.load(file_detail, is_automatic=False)
text_docs = FileExtractor.load(file_detail, is_automatic=True)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
@ -632,8 +632,8 @@ class IndexingRunner:
return text
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)
return [
{

View File

@ -23,7 +23,8 @@ FUNCTION_CALL_MODELS = [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k'
'gpt-35-turbo-16k',
'gpt-4-1106-preview'
]
class AzureOpenAIModel(BaseLLM):

View File

@ -24,7 +24,10 @@ class CohereReranking(BaseReranking):
super().__init__(model_provider, client, name)
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> \
Optional[List[Document]]:
if not documents:
return []
docs = []
doc_id = []
unique_documents = []
@ -34,7 +37,7 @@ class CohereReranking(BaseReranking):
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
rerank_documents = []

View File

@ -21,6 +21,8 @@ class XinferenceReranking(BaseReranking):
super().__init__(model_provider, client, name)
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
if not documents:
return []
docs = []
doc_id = []
unique_documents = []

View File

@ -191,23 +191,6 @@ class AnthropicProvider(BaseModelProvider):
return False
def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.
:return:
"""
if hosted_model_providers.anthropic \
and hosted_model_providers.anthropic.paid_enabled:
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""

View File

@ -122,6 +122,22 @@ class AzureOpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-1106-preview',
'name': 'gpt-4-1106-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-vision-preview',
'name': 'gpt-4-vision-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.VISION.value
]
},
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
@ -171,6 +187,8 @@ class AzureOpenAIProvider(BaseModelProvider):
base_model_max_tokens = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-4-1106-preview': 4096,
'gpt-4-vision-preview': 4096,
'gpt-35-turbo': 4096,
'gpt-35-turbo-16k': 16384,
'text-davinci-003': 4097,
@ -376,6 +394,18 @@ class AzureOpenAIProvider(BaseModelProvider):
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-4-1106-preview',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-4-vision-preview',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='text-davinci-003',
model_type=ModelType.TEXT_GENERATION,

View File

@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC):
).update({'last_used': datetime.utcnow()})
db.session.commit()
def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.
:return:
"""
return None
def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
"""
get provider model.

View File

@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel):
quota_limit: int = 0
"""Quota limit for the openai hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
class HostedAzureOpenAI(BaseModel):
@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel):
quota_limit: int = 0
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1000000
paid_min_quantity: int = 20
paid_max_quantity: int = 100
class HostedModelProviders(BaseModel):
@ -68,8 +62,6 @@ def init_app(app: Flask):
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
)
if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
@ -85,10 +77,6 @@ def init_app(app: Flask):
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):

View File

@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider):
return False
def get_payment_info(self) -> Optional[dict]:
"""
get payment info if it payable.
:return:
"""
if hosted_model_providers.openai \
and hosted_model_providers.openai.paid_enabled:
return {
'product_id': hosted_model_providers.openai.paid_stripe_price_id,
'increase_quota': hosted_model_providers.openai.paid_increase_quota,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""

View File

@ -21,6 +21,18 @@
"unit": "0.001",
"currency": "USD"
},
"gpt-4-1106-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-vision-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo": {
"prompt": "0.002",
"completion": "0.0015",

View File

@ -1,11 +1,13 @@
from typing import Dict, Any, Optional, List, Tuple, Union
from typing import Dict, Any, Optional, List, Tuple, Union, cast
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import ChatResult, BaseMessage, ChatGeneration
from pydantic import root_validator
from langchain.schema import ChatResult, BaseMessage, ChatGeneration, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
class EnhanceAzureChatOpenAI(AzureChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@ -51,13 +53,18 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [self._convert_message_to_dict(m) for m in messages]
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
@ -65,7 +72,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
params["stream"] = True
function_call: Optional[dict] = None
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
messages=message_dicts, **params
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
@ -88,4 +95,47 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
return self._create_chat_result(response)
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, LCHumanMessageWithFiles):
content = [
{
"type": "text",
"text": message.content
}
]
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
content.append({
"type": "image_url",
"image_url": {
"url": file.data,
"detail": file.detail.value
}
})
message_dict = {"role": "user", "content": content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict

View File

@ -82,7 +82,8 @@ class DatasetMultiRetrieverTool(BaseTool):
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
if 'score' in item.metadata and item.metadata['score']:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]

View File

@ -158,7 +158,8 @@ class DatasetRetrieverTool(BaseTool):
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
if 'score' in item.metadata and item.metadata['score']:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,

View File

@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
else:
return None
def get_ids_by_metadata_field(self, key: str, value: str):
result = self.col.query(
expr=f'metadata["{key}"] == "{value}"',
output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def get_ids_by_doc_ids(self, doc_ids: list):
result = self.col.query(
expr=f'metadata["doc_id"] in {doc_ids}',

View File

@ -243,7 +243,7 @@ class Weaviate(VectorStore):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
properties = ['text', 'dataset_id', 'doc_hash', 'doc_id', 'document_id']
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")

View File

@ -6,13 +6,13 @@ from models.model import AppModelConfig
@app_model_config_was_updated.connect
def handle(sender, **kwargs):
app_model = sender
app = sender
app_model_config = kwargs.get('app_model_config')
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app_model.id
AppDatasetJoin.app_id == app.id
).all()
removed_dataset_ids = []
@ -29,14 +29,14 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app_model.id,
AppDatasetJoin.app_id == app.id,
AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(
app_id=app_model.id,
app_id=app.id,
dataset_id=dataset_id
)
db.session.add(app_dataset_join)

View File

@ -1,6 +0,0 @@
import stripe
def init_app(app):
if app.config.get('STRIPE_API_KEY'):
stripe.api_key = app.config.get('STRIPE_API_KEY')

View File

@ -0,0 +1,36 @@
from flask_restful import fields
from libs.helper import TimestampField
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
annotation_fields = {
"id": fields.String,
"question": fields.String,
"answer": fields.Raw(attribute='content'),
"hit_count": fields.Integer,
"created_at": TimestampField,
# 'account': fields.Nested(account_fields, allow_null=True)
}
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
}
annotation_hit_history_fields = {
"id": fields.String,
"source": fields.String,
"score": fields.Float,
"question": fields.String,
"created_at": TimestampField,
"match": fields.String(attribute='annotation_question'),
"response": fields.String(attribute='annotation_content')
}
annotation_hit_history_list_fields = {
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
}

View File

@ -21,6 +21,7 @@ model_config_fields = {
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'annotation_reply': fields.Raw(attribute='annotation_reply_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'external_data_tools': fields.Raw(attribute='external_data_tools_list'),

View File

@ -23,11 +23,18 @@ feedback_fields = {
}
annotation_fields = {
'id': fields.String,
'question': fields.String,
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
annotation_hit_history_fields = {
'annotation_id': fields.String,
'annotation_create_account': fields.Nested(account_fields, allow_null=True)
}
message_file_fields = {
'id': fields.String,
'type': fields.String,
@ -49,6 +56,7 @@ message_detail_fields = {
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True),
'created_at': TimestampField,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
}

View File

@ -0,0 +1,50 @@
"""add_app_anntation_setting
Revision ID: 246ba09cbbdb
Revises: 714aafe25d39
Create Date: 2023-12-14 11:26:12.287264
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '246ba09cbbdb'
down_revision = '714aafe25d39'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_annotation_settings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', postgresql.UUID(), nullable=False),
sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False),
sa.Column('collection_binding_id', postgresql.UUID(), nullable=False),
sa.Column('created_user_id', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_user_id', postgresql.UUID(), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey')
)
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('annotation_reply')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.drop_index('app_annotation_settings_app_idx')
op.drop_table('app_annotation_settings')
# ### end Alembic commands ###

View File

@ -0,0 +1,32 @@
"""add-annotation-histoiry-score
Revision ID: 46976cc39132
Revises: e1901f623fd0
Create Date: 2023-12-13 04:39:59.302971
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '46976cc39132'
down_revision = 'e1901f623fd0'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_column('score')
# ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
"""add_anntation_history_match_response
Revision ID: 714aafe25d39
Revises: f2a6fc85e260
Create Date: 2023-12-14 06:38:02.972527
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '714aafe25d39'
down_revision = 'f2a6fc85e260'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_column('annotation_content')
batch_op.drop_column('annotation_question')
# ### end Alembic commands ###

View File

@ -0,0 +1,32 @@
"""add custom config in tenant
Revision ID: 88072f0caa04
Revises: fca025d3b60f
Create Date: 2023-12-14 07:36:50.705362
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '88072f0caa04'
down_revision = '246ba09cbbdb'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenants', schema=None) as batch_op:
batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenants', schema=None) as batch_op:
batch_op.drop_column('custom_config')
# ### end Alembic commands ###

View File

@ -0,0 +1,79 @@
"""add-annotation-reply
Revision ID: e1901f623fd0
Revises: fca025d3b60f
Create Date: 2023-12-12 06:58:41.054544
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'e1901f623fd0'
down_revision = 'fca025d3b60f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_annotation_hit_histories',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', postgresql.UUID(), nullable=False),
sa.Column('annotation_id', postgresql.UUID(), nullable=False),
sa.Column('source', sa.Text(), nullable=False),
sa.Column('question', sa.Text(), nullable=False),
sa.Column('account_id', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey')
)
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.create_index('app_annotation_hit_histories_account_idx', ['account_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'::character varying"), nullable=False))
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=True)
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.drop_column('hit_count')
batch_op.drop_column('question')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_column('type')
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('annotation_reply')
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_index('app_annotation_hit_histories_app_idx')
batch_op.drop_index('app_annotation_hit_histories_annotation_idx')
batch_op.drop_index('app_annotation_hit_histories_account_idx')
op.drop_table('app_annotation_hit_histories')
# ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
"""add_anntation_history_message_id
Revision ID: f2a6fc85e260
Revises: 46976cc39132
Create Date: 2023-12-13 11:09:29.329584
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'f2a6fc85e260'
down_revision = '46976cc39132'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.drop_index('app_annotation_hit_histories_message_idx')
batch_op.drop_column('message_id')
# ### end Alembic commands ###

View File

@ -1,4 +1,6 @@
import json
import enum
from math import e
from typing import List
from flask_login import UserMixin
@ -112,6 +114,7 @@ class Tenant(db.Model):
encrypt_public_key = db.Column(db.Text)
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
custom_config = db.Column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -121,6 +124,14 @@ class Tenant(db.Model):
Account.id == TenantAccountJoin.account_id,
TenantAccountJoin.tenant_id == self.id
).all()
@property
def custom_config_dict(self) -> dict:
return json.loads(self.custom_config) if self.custom_config else {}
@custom_config_dict.setter
def custom_config_dict(self, value: dict):
self.custom_config = json.dumps(value)
class TenantAccountJoinRole(enum.Enum):

View File

@ -135,7 +135,7 @@ class DatasetProcessRule(db.Model):
],
'segmentation': {
'delimiter': '\n',
'max_tokens': 512
'max_tokens': 1000
}
}
@ -475,5 +475,6 @@ class DatasetCollectionBinding(db.Model):
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -2,6 +2,7 @@ import json
from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy import Float
from sqlalchemy.dialects.postgresql import UUID
from core.file.upload_file_parser import UploadFileParser
@ -128,6 +129,25 @@ class AppModelConfig(db.Model):
return json.loads(self.retriever_resource) if self.retriever_resource \
else {"enabled": False}
@property
def annotation_reply_dict(self) -> dict:
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == self.app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name
}
}
else:
return {"enabled": False}
@property
def more_like_this_dict(self) -> dict:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@ -170,7 +190,9 @@ class AppModelConfig(db.Model):
@property
def file_upload_dict(self) -> dict:
return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
return json.loads(self.file_upload) if self.file_upload else {
"image": {"enabled": False, "number_limits": 3, "detail": "high",
"transfer_methods": ["remote_url", "local_file"]}}
def to_dict(self) -> dict:
return {
@ -182,6 +204,7 @@ class AppModelConfig(db.Model):
"suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
"speech_to_text": self.speech_to_text_dict,
"retriever_resource": self.retriever_resource_dict,
"annotation_reply": self.annotation_reply_dict,
"more_like_this": self.more_like_this_dict,
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"external_data_tools": self.external_data_tools_list,
@ -504,6 +527,12 @@ class Message(db.Model):
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first()
return annotation
@property
def annotation_hit_history(self):
annotation_history = (db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.message_id == self.id).first())
return annotation_history
@property
def app_model_config(self):
conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
@ -616,9 +645,11 @@ class MessageAnnotation(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
message_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(UUID, nullable=True)
question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
account_id = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -629,6 +660,79 @@ class MessageAnnotation(db.Model):
return account
class AppAnnotationHitHistory(db.Model):
__tablename__ = 'app_annotation_hit_histories'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'),
db.Index('app_annotation_hit_histories_app_idx', 'app_id'),
db.Index('app_annotation_hit_histories_account_idx', 'account_id'),
db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'),
db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
annotation_id = db.Column(UUID, nullable=False)
source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False)
account_id = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
score = db.Column(Float, nullable=False, server_default=db.text('0'))
message_id = db.Column(UUID, nullable=False)
annotation_question = db.Column(db.Text, nullable=False)
annotation_content = db.Column(db.Text, nullable=False)
@property
def account(self):
account = (db.session.query(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.filter(MessageAnnotation.id == self.annotation_id).first())
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).filter(Account.id == self.account_id).first()
return account
class AppAnnotationSetting(db.Model):
__tablename__ = 'app_annotation_settings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'),
db.Index('app_annotation_settings_app_idx', 'app_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
collection_binding_id = db.Column(UUID, nullable=False)
created_user_id = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_user_id = db.Column(UUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
def created_account(self):
account = (db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id).first())
return account
@property
def updated_account(self):
account = (db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id).first())
return account
@property
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
collection_binding_detail = (db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == self.collection_binding_id).first())
return collection_binding_detail
class OperationLog(db.Model):
__tablename__ = 'operation_logs'
__table_args__ = (

View File

@ -135,21 +135,6 @@ class TenantPreferredModelProvider(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class ProviderOrderPaymentStatus(Enum):
WAIT_PAY = 'wait_pay'
PAID = 'paid'
PAY_FAILED = 'pay_failed'
REFUNDED = 'refunded'
@staticmethod
def value_of(value):
for member in ProviderOrderPaymentStatus:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderOrder(db.Model):
__tablename__ = 'provider_orders'
__table_args__ = (

View File

@ -46,7 +46,6 @@ websocket-client~=1.6.1
dashscope~=1.11.0
huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference-client~=0.6.4
safetensors==0.3.2
@ -54,4 +53,6 @@ zhipuai==1.0.7
werkzeug==2.3.7
pymilvus==2.3.0
qdrant-client==1.6.4
cohere~=4.32
cohere~=4.32
unstructured~=0.10.27
unstructured[docx,pptx]~=0.10.27

View File

@ -412,6 +412,12 @@ class TenantService:
db.session.delete(tenant)
db.session.commit()
@staticmethod
def get_custom_config(tenant_id: str) -> None:
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404()
return tenant.custom_config_dict
class RegisterService:

View File

@ -0,0 +1,426 @@
import datetime
import json
import uuid
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import MessageAnnotation, Message, App, AppAnnotationHitHistory, AppAnnotationSetting
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
if 'message_id' in args and args['message_id']:
message_id = str(args['message_id'])
# get message info
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app.id
).first()
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
# save the message annotation
if annotation:
annotation.content = args['answer']
annotation.question = args['question']
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args['answer'],
question=args['question'],
account_id=current_user.id
)
else:
annotation = MessageAnnotation(
app_id=app.id,
content=args['answer'],
question=args['question'],
account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
app_id, annotation_setting.collection_binding_id)
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
return {
'job_id': cache_result,
'job_status': 'processing'
}
# async job
job_id = str(uuid.uuid4())
enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, 'waiting')
enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
args['score_threshold'],
args['embedding_provider_name'], args['embedding_model_name'])
return {
'job_id': job_id,
'job_status': 'waiting'
}
@classmethod
def disable_app_annotation(cls, app_id: str) -> dict:
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
return {
'job_id': cache_result,
'job_status': 'processing'
}
# async job
job_id = str(uuid.uuid4())
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
# send batch add segments task
redis_client.setnx(disable_app_annotation_job_key, 'waiting')
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
return {
'job_id': job_id,
'job_status': 'waiting'
}
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
if keyword:
annotations = (db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike('%{}%'.format(keyword)),
MessageAnnotation.content.ilike('%{}%'.format(keyword))
)
)
.order_by(MessageAnnotation.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
else:
annotations = (db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
return annotations.items, annotations.total
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotations = (db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc()).all())
return annotations
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotation = MessageAnnotation(
app_id=app.id,
content=args['answer'],
question=args['question'],
account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
app_id, annotation_setting.collection_binding_id)
return annotation
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
annotation.content = args['answer']
annotation.question = args['question']
db.session.commit()
# if annotation reply is enabled , add annotation to index
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id
).first()
if app_annotation_setting:
update_annotation_to_index_task.delay(annotation.id, annotation.question,
current_user.current_tenant_id,
app_id, app_annotation_setting.collection_binding_id)
return annotation
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
db.session.delete(annotation)
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)
db.session.commit()
# if annotation reply is enabled , delete annotation index
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id
).first()
if app_annotation_setting:
delete_annotation_index_task.delay(annotation.id, app_id,
current_user.current_tenant_id,
app_annotation_setting.collection_binding_id)
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
try:
# Skip the first row
df = pd.read_csv(file)
result = []
for index, row in df.iterrows():
content = {
'question': row[0],
'answer': row[1]
}
result.append(content)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
# send batch add segments task
redis_client.setnx(indexing_cache_key, 'waiting')
batch_import_annotations_task.delay(str(job_id), result, app_id,
current_user.current_tenant_id, current_user.id)
except Exception as e:
return {
'error_msg': str(e)
}
return {
'job_id': job_id,
'job_status': 'waiting'
}
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
return annotation_hit_histories.items, annotation_hit_histories.total
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
if not annotation:
return None
return annotation
@classmethod
def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
annotation_content: str, query: str, user_id: str,
message_id: str, from_source: str, score: float):
# add hit count to annotation
db.session.query(MessageAnnotation).filter(
MessageAnnotation.id == annotation_id
).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
synchronize_session=False
)
annotation_hit_history = AppAnnotationHitHistory(
annotation_id=annotation_id,
app_id=app_id,
account_id=user_id,
question=query,
source=from_source,
score=score,
message_id=message_id,
annotation_question=annotation_question,
annotation_content=annotation_content
)
db.session.add(annotation_hit_history)
db.session.commit()
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name
}
}
return {
"enabled": False
}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
).first()
if not annotation_setting:
raise NotFound("App annotation not found")
annotation_setting.score_threshold = args['score_threshold']
annotation_setting.updated_user_id = current_user.id
annotation_setting.updated_at = datetime.datetime.utcnow()
db.session.add(annotation_setting)
db.session.commit()
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name
}
}

View File

@ -138,7 +138,22 @@ class AppModelConfigService:
config["retriever_resource"]["enabled"] = False
if not isinstance(config["retriever_resource"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type")
raise ValueError("enabled in retriever_resource must be of boolean type")
# annotation reply
if 'annotation_reply' not in config or not config["annotation_reply"]:
config["annotation_reply"] = {
"enabled": False
}
if not isinstance(config["annotation_reply"], dict):
raise ValueError("annotation_reply must be of dict type")
if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
config["annotation_reply"]["enabled"] = False
if not isinstance(config["annotation_reply"]["enabled"], bool):
raise ValueError("enabled in annotation_reply must be of boolean type")
# more_like_this
if 'more_like_this' not in config or not config["more_like_this"]:
@ -325,6 +340,7 @@ class AppModelConfigService:
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"],
"retriever_resource": config["retriever_resource"],
"annotation_reply": config["annotation_reply"],
"more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
"external_data_tools": config["external_data_tools"],

View File

@ -1,6 +1,10 @@
import os
import requests
from extensions.ext_database import db
from models.account import TenantAccountJoin
class BillingService:
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
@ -10,7 +14,7 @@ class BillingService:
def get_info(cls, tenant_id: str):
params = {'tenant_id': tenant_id}
billing_info = cls._send_request('GET', '/info', params=params)
billing_info = cls._send_request('GET', '/subscription/info', params=params)
return billing_info
@ -18,16 +22,26 @@ class BillingService:
def get_subscription(cls, plan: str,
interval: str,
prefilled_email: str = '',
user_name: str = '',
tenant_id: str = ''):
params = {
'plan': plan,
'interval': interval,
'prefilled_email': prefilled_email,
'user_name': user_name,
'tenant_id': tenant_id
}
return cls._send_request('GET', '/subscription', params=params)
return cls._send_request('GET', '/subscription/payment-link', params=params)
@classmethod
def get_model_provider_payment_link(cls,
provider_name: str,
tenant_id: str,
account_id: str):
params = {
'provider_name': provider_name,
'tenant_id': tenant_id,
'account_id': account_id
}
return cls._send_request('GET', '/model-provider/payment-link', params=params)
@classmethod
def get_invoices(cls, prefilled_email: str = ''):
@ -46,9 +60,14 @@ class BillingService:
return response.json()
@classmethod
def process_event(cls, event: dict):
json = {
"content": event,
}
return cls._send_request('POST', '/webhook/stripe', json=json)
@staticmethod
def is_tenant_owner(current_user):
tenant_id = current_user.current_tenant_id
join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.tenant_id == tenant_id,
TenantAccountJoin.account_id == current_user.id
).first()
if join.role != 'owner':
raise ValueError('Only tenant owner can perform this action')

View File

@ -165,7 +165,8 @@ class CompletionService:
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
'auto_generate_name': auto_generate_name
'auto_generate_name': auto_generate_name,
'from_source': from_source
})
generate_worker_thread.start()
@ -193,7 +194,7 @@ class CompletionService:
query: str, inputs: dict, files: List[PromptMessageFile],
detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev', auto_generate_name: bool = True):
retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
@ -218,7 +219,8 @@ class CompletionService:
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from,
auto_generate_name=auto_generate_name
auto_generate_name=auto_generate_name,
from_source=from_source
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
pass
@ -385,6 +387,9 @@ class CompletionService:
result = json.loads(result)
if result.get('error'):
cls.handle_error(result)
if result['event'] == 'annotation' and 'data' in result:
message_result['annotation'] = result.get('data')
return cls.get_blocking_annotation_message_response_data(message_result)
if result['event'] == 'message' and 'data' in result:
message_result['message'] = result.get('data')
if result['event'] == 'message_end' and 'data' in result:
@ -427,6 +432,9 @@ class CompletionService:
elif event == 'agent_thought':
yield "data: " + json.dumps(
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'annotation':
yield "data: " + json.dumps(
cls.get_annotation_response_data(result.get('data'))) + "\n\n"
elif event == 'message_end':
yield "data: " + json.dumps(
cls.get_message_end_data(result.get('data'))) + "\n\n"
@ -499,6 +507,25 @@ class CompletionService:
return response_data
@classmethod
def get_blocking_annotation_message_response_data(cls, data: dict):
message = data.get('annotation')
response_data = {
'event': 'annotation',
'task_id': message.get('task_id'),
'id': message.get('message_id'),
'answer': message.get('text'),
'metadata': {},
'created_at': int(time.time()),
'annotation_id': message.get('annotation_id'),
'annotation_author_name': message.get('annotation_author_name')
}
if message.get('mode') == 'chat':
response_data['conversation_id'] = message.get('conversation_id')
return response_data
@classmethod
def get_message_end_data(cls, data: dict):
response_data = {
@ -551,6 +578,23 @@ class CompletionService:
return response_data
@classmethod
def get_annotation_response_data(cls, data: dict):
response_data = {
'event': 'annotation',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time()),
'annotation_id': data.get('annotation_id'),
'annotation_author_name': data.get('annotation_author_name'),
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def handle_error(cls, result: dict):
logging.debug("error: %s", result)

View File

@ -33,10 +33,7 @@ from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@ -1175,10 +1172,12 @@ class SegmentService:
class DatasetCollectionBindingService:
@classmethod
def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
collection_type: str = 'dataset') -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name). \
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type). \
order_by(DatasetCollectionBinding.created_at). \
first()
@ -1186,8 +1185,20 @@ class DatasetCollectionBindingService:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=provider_name,
model_name=model_name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node',
type=collection_type
)
db.session.add(dataset_collection_binding)
db.session.flush()
db.session.commit()
return dataset_collection_binding
@classmethod
def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
collection_type: str = 'dataset') -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == collection_binding_id,
DatasetCollectionBinding.type == collection_type). \
order_by(DatasetCollectionBinding.created_at). \
first()
return dataset_collection_binding

View File

@ -17,8 +17,8 @@ from models.model import UploadFile, EndUser
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
'jpg', 'jpeg', 'png', 'webp', 'gif']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
PREVIEW_WORDS_LIMIT = 3000
@ -27,7 +27,13 @@ class FileService:
@staticmethod
def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
etl_type = current_app.config['ETL_TYPE']
if etl_type == 'Unstructured':
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
else:
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
if extension.lower() not in allowed_extensions:
raise UnsupportedFileTypeError()
elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
@ -154,3 +160,21 @@ class FileService:
generator = storage.load(upload_file.key, stream=True)
return generator, upload_file.mime_type
@staticmethod
def get_public_image_preview(file_id: str) -> str:
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
# extract text from file
extension = upload_file.extension
if extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
generator = storage.load(upload_file.key)
return generator, upload_file.mime_type

View File

@ -1,174 +0,0 @@
import datetime
import logging
import stripe
from flask import current_app
from core.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from models.account import Account
from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType
class ProviderCheckout:
def __init__(self, stripe_checkout_session):
self.stripe_checkout_session = stripe_checkout_session
def get_checkout_url(self):
return self.stripe_checkout_session.url
class ProviderCheckoutService:
def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout:
# check provider name is valid
model_provider_rules = ModelProviderFactory.get_provider_rules()
if provider_name not in model_provider_rules:
raise ValueError(f'provider name {provider_name} is invalid')
model_provider_rule = model_provider_rules[provider_name]
# check provider name can be paid
self._check_provider_payable(provider_name, model_provider_rule)
# get stripe checkout product id
paid_provider = self._get_paid_provider(tenant_id, provider_name)
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
model_provider = model_provider_class(provider=paid_provider)
payment_info = model_provider.get_payment_info()
if not payment_info:
raise ValueError(f'provider name {provider_name} not support payment')
payment_product_id = payment_info['product_id']
payment_min_quantity = payment_info['min_quantity']
payment_max_quantity = payment_info['max_quantity']
# create provider order
provider_order = ProviderOrder(
tenant_id=tenant_id,
provider_name=provider_name,
account_id=account.id,
payment_product_id=payment_product_id,
quantity=1,
payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value
)
db.session.add(provider_order)
db.session.flush()
line_item = {
'price': f'{payment_product_id}',
'quantity': payment_min_quantity
}
if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity:
line_item['adjustable_quantity'] = {
'enabled': True,
'minimum': payment_min_quantity,
'maximum': payment_max_quantity
}
try:
# create stripe checkout session
checkout_session = stripe.checkout.Session.create(
line_items=[
line_item
],
mode='payment',
success_url=current_app.config.get("CONSOLE_WEB_URL")
+ f'?provider_name={provider_name}&payment_result=succeeded',
cancel_url=current_app.config.get("CONSOLE_WEB_URL")
+ f'?provider_name={provider_name}&payment_result=cancelled',
automatic_tax={'enabled': True},
)
except Exception as e:
logging.exception(e)
raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later')
provider_order.payment_id = checkout_session.id
db.session.commit()
return ProviderCheckout(checkout_session)
def fulfill_provider_order(self, event, line_items):
provider_order = db.session.query(ProviderOrder) \
.filter(ProviderOrder.payment_id == event['data']['object']['id']) \
.first()
if not provider_order:
raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')
if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
raise ValueError(
f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')
provider_order.transaction_id = event['data']['object']['payment_intent']
provider_order.currency = event['data']['object']['currency']
provider_order.total_amount = event['data']['object']['amount_subtotal']
provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value
provider_order.paid_at = datetime.datetime.utcnow()
provider_order.updated_at = provider_order.paid_at
# update provider quota
provider = db.session.query(Provider).filter(
Provider.tenant_id == provider_order.tenant_id,
Provider.provider_name == provider_order.provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.PAID.value
).first()
if not provider:
raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, '
f'provider name: {provider_order.provider_name}')
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name)
model_provider = model_provider_class(provider=provider)
payment_info = model_provider.get_payment_info()
quantity = line_items['data'][0]['quantity']
if not payment_info:
increase_quota = 0
else:
increase_quota = int(payment_info['increase_quota']) * quantity
if increase_quota > 0:
provider.quota_limit += increase_quota
provider.is_valid = True
db.session.commit()
def _check_provider_payable(self, provider_name: str, model_provider_rule: dict):
if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']:
raise ValueError(f'provider name {provider_name} not support payment')
if 'system_config' not in model_provider_rule:
raise ValueError(f'provider name {provider_name} not support payment')
if 'supported_quota_types' not in model_provider_rule['system_config']:
raise ValueError(f'provider name {provider_name} not support payment')
if 'paid' not in model_provider_rule['system_config']['supported_quota_types']:
raise ValueError(f'provider name {provider_name} not support payment')
def _get_paid_provider(self, tenant_id: str, provider_name: str):
paid_provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.PAID.value,
).first()
if not paid_provider:
paid_provider = Provider(
tenant_id=tenant_id,
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.PAID.value,
quota_limit=0,
quota_used=0,
)
db.session.add(paid_provider)
db.session.commit()
return paid_provider

View File

@ -1,8 +1,12 @@
from flask import current_app
from flask_login import current_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin
from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole
from models.provider import Provider
from services.billing_service import BillingService
from services.account_service import TenantService
class WorkspaceService:
@classmethod
@ -28,6 +32,13 @@ class WorkspaceService:
).first()
tenant_info['role'] = tenant_account_join.role
edition = current_app.config['EDITION']
if edition == 'CLOUD':
billing_info = BillingService.get_info(tenant_info['id'])
if billing_info['can_replace_logo'] and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
tenant_info['custom_config'] = tenant.custom_config_dict
# Get providers
providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id

View File

@ -0,0 +1,61 @@
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
collection_binding_id: str):
"""
Add annotation to index.
:param annotation_id: annotation id
:param question: question
:param tenant_id: tenant id
:param app_id: app id
:param collection_binding_id: embedding binding id
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
"""
logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green'))
start_at = time.perf_counter()
try:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id,
'annotation'
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique='high_quality',
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
)
document = Document(
page_content=question,
metadata={
"annotation_id": annotation_id,
"app_id": app_id,
"doc_id": annotation_id
}
)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts([document])
end_at = time.perf_counter()
logging.info(
click.style(
'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
fg='green'))
except Exception:
logging.exception("Build index for annotation failed")

View File

@ -0,0 +1,99 @@
import json
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str,
user_id: str):
"""
Add annotation to index.
:param job_id: job_id
:param content_list: content list
:param tenant_id: tenant id
:param app_id: app id
:param user_id: user_id
"""
logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green'))
start_at = time.perf_counter()
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == tenant_id,
App.status == 'normal'
).first()
if app:
try:
documents = []
for content in content_list:
annotation = MessageAnnotation(
app_id=app.id,
content=content['answer'],
question=content['question'],
account_id=user_id
)
db.session.add(annotation)
db.session.flush()
document = Document(
page_content=content['question'],
metadata={
"annotation_id": annotation.id,
"app_id": app_id,
"doc_id": annotation.id
}
)
documents.append(document)
# if annotation reply is enabled , batch add annotations' index
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id
).first()
if app_annotation_setting:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
app_annotation_setting.collection_binding_id,
'annotation'
)
if not dataset_collection_binding:
raise NotFound("App annotation setting not found")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique='high_quality',
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, 'completed')
end_at = time.perf_counter()
logging.info(
click.style(
'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at),
fg='green'))
except Exception as e:
db.session.rollback()
redis_client.setex(indexing_cache_key, 600, 'error')
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
redis_client.setex(indexing_error_msg_key, 600, str(e))
logging.exception("Build index for batch import annotations failed")

View File

@ -0,0 +1,45 @@
import datetime
import logging
import time
import click
from celery import shared_task
from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str,
collection_binding_id: str):
"""
Async delete annotation index task
"""
logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green'))
start_at = time.perf_counter()
try:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id,
'annotation'
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique='high_quality',
collection_binding_id=dataset_collection_binding.id
)
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
if vector_index:
try:
vector_index.delete_by_metadata_field('annotation_id', annotation_id)
except Exception:
logging.exception("Delete annotation index failed when annotation deleted.")
end_at = time.perf_counter()
logging.info(
click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
fg='green'))
except Exception as e:
logging.exception("Annotation deleted index failed:{}".format(str(e)))

View File

@ -0,0 +1,74 @@
import datetime
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
@shared_task(queue='dataset')
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
"""
Async enable annotation reply task
"""
logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green'))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id
).first()
if not app_annotation_setting:
raise NotFound("App annotation setting not found")
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
try:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique='high_quality',
collection_binding_id=app_annotation_setting.collection_binding_id
)
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
if vector_index:
try:
vector_index.delete_by_metadata_field('app_id', app_id)
except Exception:
logging.exception("Delete doc index failed when dataset deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, 'completed')
# delete annotation setting
db.session.delete(app_annotation_setting)
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
fg='green'))
except Exception as e:
logging.exception("Annotation batch deleted index failed:{}".format(str(e)))
redis_client.setex(disable_app_annotation_job_key, 600, 'error')
disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id))
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)

View File

@ -0,0 +1,106 @@
import datetime
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float,
embedding_provider_name: str, embedding_model_name: str):
"""
Async enable annotation reply task
"""
logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green'))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
'annotation'
)
annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
annotation_setting.updated_at = datetime.datetime.utcnow()
db.session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
app_id=app_id,
score_threshold=score_threshold,
collection_binding_id=dataset_collection_binding.id,
created_user_id=user_id,
updated_user_id=user_id
)
db.session.add(new_app_annotation_setting)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique='high_quality',
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id
)
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={
"annotation_id": annotation.id,
"app_id": app_id,
"doc_id": annotation.id
}
)
documents.append(document)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
try:
index.delete_by_metadata_field('app_id', app_id)
except Exception as e:
logging.info(
click.style('Delete annotation index error: {}'.format(str(e)),
fg='red'))
index.add_texts(documents)
db.session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
end_at = time.perf_counter()
logging.info(
click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at),
fg='green'))
except Exception as e:
logging.exception("Annotation batch created index failed:{}".format(str(e)))
redis_client.setex(enable_app_annotation_job_key, 600, 'error')
enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id))
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
db.session.rollback()
finally:
redis_client.delete(enable_app_annotation_key)

View File

@ -0,0 +1,63 @@
import logging
import time
import click
from celery import shared_task
from langchain.schema import Document
from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@shared_task(queue='dataset')
def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
collection_binding_id: str):
"""
Update annotation to index.
:param annotation_id: annotation id
:param question: question
:param tenant_id: tenant id
:param app_id: app id
:param collection_binding_id: embedding binding id
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
"""
logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green'))
start_at = time.perf_counter()
try:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id,
'annotation'
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique='high_quality',
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
)
document = Document(
page_content=question,
metadata={
"annotation_id": annotation_id,
"app_id": app_id,
"doc_id": annotation_id
}
)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.delete_by_metadata_field('annotation_id', annotation_id)
index.add_texts([document])
end_at = time.perf_counter()
logging.info(
click.style(
'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
fg='green'))
except Exception:
logging.exception("Build index for annotation failed")

View File

@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.3.33
image: langgenius/dify-api:0.3.34
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -128,7 +128,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.3.33
image: langgenius/dify-api:0.3.34
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@ -196,7 +196,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.3.33
image: langgenius/dify-web:0.3.34
restart: always
environment:
EDITION: SELF_HOSTED

View File

@ -0,0 +1,17 @@
import React from 'react'
import Main from '@/app/components/app/log-annotation'
import { PageType } from '@/app/components/app/configuration/toolbox/annotation/type'
export type IProps = {
params: { appId: string }
}
const Logs = async ({
params: { appId },
}: IProps) => {
return (
<Main pageType={PageType.annotation} appId={appId} />
)
}
export default Logs

View File

@ -1,5 +1,6 @@
import React from 'react'
import Main from '@/app/components/app/log'
import Main from '@/app/components/app/log-annotation'
import { PageType } from '@/app/components/app/configuration/toolbox/annotation/type'
export type IProps = {
params: { appId: string }
@ -9,7 +10,7 @@ const Logs = async ({
params: { appId },
}: IProps) => {
return (
<Main appId={appId} />
<Main pageType={PageType.log} appId={appId} />
)
}

View File

@ -2,11 +2,11 @@ import classNames from 'classnames'
import style from '../list.module.css'
import Apps from './Apps'
import { getLocaleOnServer } from '@/i18n/server'
import { useTranslation } from '@/i18n/i18next-serverside-config'
import { useTranslation as translate } from '@/i18n/i18next-serverside-config'
const AppList = async () => {
const locale = getLocaleOnServer()
const { t } = await useTranslation(locale, 'app')
const { t } = await translate(locale, 'app')
return (
<div className='flex flex-col overflow-auto bg-gray-100 shrink-0 grow'>

View File

@ -1,6 +1,6 @@
import React from 'react'
import { getLocaleOnServer } from '@/i18n/server'
import { useTranslation } from '@/i18n/i18next-serverside-config'
import { useTranslation as translate } from '@/i18n/i18next-serverside-config'
import Form from '@/app/components/datasets/settings/form'
type Props = {
@ -11,8 +11,7 @@ const Settings = async ({
params: { datasetId },
}: Props) => {
const locale = getLocaleOnServer()
// eslint-disable-next-line react-hooks/rules-of-hooks
const { t } = await useTranslation(locale, 'dataset-settings')
const { t } = await translate(locale, 'dataset-settings')
return (
<div className='bg-white h-full overflow-y-auto'>

View File

@ -28,7 +28,15 @@ export default function NavLink({
mode = 'expand',
}: NavLinkProps) {
const segment = useSelectedLayoutSegment()
const isActive = href.toLowerCase().split('/')?.pop() === segment?.toLowerCase()
const formattedSegment = (() => {
let res = segment?.toLowerCase()
// logs and annotations use the same nav
if (res === 'annotations')
res = 'logs'
return res
})()
const isActive = href.toLowerCase().split('/')?.pop() === formattedSegment
const NavIcon = isActive ? iconMap.selected : iconMap.normal
return (

View File

@ -0,0 +1,47 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import Textarea from 'rc-textarea'
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
export enum EditItemType {
Query = 'query',
Answer = 'answer',
}
type Props = {
type: EditItemType
content: string
onChange: (content: string) => void
}
const EditItem: FC<Props> = ({
type,
content,
onChange,
}) => {
const { t } = useTranslation()
const avatar = type === EditItemType.Query ? <User className='w-6 h-6' /> : <Robot className='w-6 h-6' />
const name = type === EditItemType.Query ? t('appAnnotation.addModal.queryName') : t('appAnnotation.addModal.answerName')
const placeholder = type === EditItemType.Query ? t('appAnnotation.addModal.queryPlaceholder') : t('appAnnotation.addModal.answerPlaceholder')
return (
<div className='flex' onClick={e => e.stopPropagation()}>
<div className='shrink-0 mr-3'>
{avatar}
</div>
<div className='grow'>
<div className='mb-1 leading-[18px] text-xs font-semibold text-gray-900'>{name}</div>
<Textarea
className='mt-1 block w-full leading-5 max-h-none text-sm text-gray-700 outline-none appearance-none resize-none'
value={content}
onChange={(e: React.ChangeEvent<HTMLTextAreaElement>) => onChange(e.target.value)}
autoSize={{ minRows: 3 }}
placeholder={placeholder}
autoFocus
/>
</div>
</div>
)
}
export default React.memo(EditItem)

Some files were not shown because too many files have changed in this diff Show More