Mirror of https://github.com/langgenius/dify.git (synced 2026-01-31 17:07:01 +08:00)
Compare commits
35 Commits
Commit SHAs:
43741ad5d1, 8dec406161, 58f8d74591, 867fc61b12, 8e2e477a7f, 9b34f5a9ff, 5e34f938c1,
2fd56cb01c, 4f0e272549, 1a5279a3ef, 7775f5785f, 2de73991ff, 354d033e60, ebc2cdad2e,
5bb841935e, 65fd4b39ce, 96d2de2258, a71f2863ac, a9b942981d, 4b1ba2ec21, c09184fd94,
b0d8d196e1, 7c43123956, eede84eb9e, b5b20234e9, 5beb298e47, 6b499b9a16, 4c639961f5,
dfd3f507fb, d5695b3170, 994fceece3, 8c451eb0e6, 79b4366203, 3675d2eae8, 38b55d2186
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, 13 changes)
@@ -1,11 +1,18 @@
name: "🕷️ Bug report"
description: Report errors or unexpected behavior [please use English :)]
description: Report errors or unexpected behavior
labels:
  - bug
body:
  - type: markdown
  - type: checkboxes
    attributes:
      value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
      label: Self Checks
      description: "To make sure we get to you in time, please check the following :)"
      options:
        - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
          required: true
        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true

  - type: input
    attributes:
      label: Dify version
.github/ISSUE_TEMPLATE/document_issue.yml (vendored, 10 changes)
@@ -1,8 +1,16 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation [please use English :)]
description: Report issues in our documentation
labels:
  - ducumentation
body:
  - type: checkboxes
    attributes:
      label: Self Checks
      description: "To make sure we get to you in time, please check the following :)"
      - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
        required: true
      - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
        required: true
  - type: textarea
    attributes:
      label: Provide a description of requested docs changes
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, 11 changes)
@@ -1,8 +1,17 @@
name: "⭐ Feature or enhancement request"
description: Propose something new. [please use English :)]
description: Propose something new.
labels:
  - enhancement
body:
  - type: checkboxes
    attributes:
      label: Self Checks
      description: "To make sure we get to you in time, please check the following :)"
      options:
        - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
          required: true
        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
  - type: textarea
    attributes:
      label: Description of the new feature / enhancement
.github/ISSUE_TEMPLATE/help_wanted.yml (vendored, 9 changes)
@@ -3,6 +3,15 @@ description: "Request help from the community" [please use English :)]
labels:
  - help-wanted
body:
  - type: checkboxes
    attributes:
      label: Self Checks
      description: "To make sure we get to you in time, please check the following :)"
      options:
        - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
          required: true
        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
  - type: textarea
    attributes:
      label: Provide a description of the help you need
.github/ISSUE_TEMPLATE/translation_issue.yml (vendored, 10 changes)
@@ -3,9 +3,15 @@ description: Report incorrect translations. [please use English :)]
labels:
  - translation
body:
  - type: markdown
  - type: checkboxes
    attributes:
      value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
      label: Self Checks
      description: "To make sure we get to you in time, please check the following :)"
      options:
        - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
          required: true
        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
          required: true
  - type: input
    attributes:
      label: Dify version
@@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1

HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
@@ -119,16 +117,6 @@ HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100

# Stripe configuration
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

# Billing configuration
BILLING_API_URL=http://127.0.0.1:8000/v1
BILLING_API_SECRET_KEY=
STRIPE_WEBHOOK_BILLING_SECRET=
ETL_TYPE=dify
UNSTRUCTURED_API_URL=
@@ -20,7 +20,7 @@ from flask_cors import CORS

from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
    ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
    ext_database, ext_storage, ext_mail, ext_code_based_extension
from extensions.ext_database import db
from extensions.ext_login import login_manager

@@ -96,7 +96,6 @@ def initialize_extensions(app):
    ext_login.init_app(app)
    ext_mail.init_app(app)
    ext_sentry.init_app(app)
    ext_stripe.init_app(app)


# Flask-Login configuration
@@ -28,7 +28,7 @@ from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
import secrets
import base64

@@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
        pbar.update(len(data_batch))


@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
    click.echo(click.style('Start add annotation question value.', fg='green'))
    message_annotations = db.session.query(MessageAnnotation).all()
    message_annotation_deal_count = 0
    if message_annotations:
        for message_annotation in message_annotations:
            try:
                if message_annotation.message_id and not message_annotation.question:
                    message = db.session.query(Message).filter(
                        Message.id == message_annotation.message_id
                    ).first()
                    message_annotation.question = message.query
                    db.session.add(message_annotation)
                    db.session.commit()
                    message_annotation_deal_count += 1
            except Exception as e:
                click.echo(
                    click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))
    click.echo(
        click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))


def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
@@ -766,3 +790,4 @@ def register_commands(app):
    app.cli.add_command(normalization_collections)
    app.cli.add_command(migrate_default_input_to_dataset_query_variable)
    app.cli.add_command(add_qdrant_full_text_index)
    app.cli.add_command(add_annotation_question_field_value)
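Because the backfill command is registered on the Flask CLI like the existing maintenance commands, it can be run as `flask add-annotation-question-field-value` from the api directory. A minimal sketch of driving it from Python instead, assuming the application object is importable as `app` from the project's app module (an assumption about the local layout, not something this diff states):

# Sketch: exercising the new backfill command through Flask's CLI test runner.
# The `from app import app` path is an assumed project layout.
from app import app

runner = app.test_cli_runner()
result = runner.invoke(args=['add-annotation-question-field-value'])
print(result.output)  # expected: the green "Start add annotation question value." line, then the deal count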
@@ -1,11 +1,8 @@
# -*- coding:utf-8 -*-
import os
from datetime import timedelta

import dotenv

from extensions.ext_database import db
from extensions.ext_redis import redis_client

dotenv.load_dotenv()

@@ -44,15 +41,11 @@ DEFAULTS = {
    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_OPENAI_ENABLED': 'False',
    'HOSTED_OPENAI_PAID_ENABLED': 'False',
    'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
    'HOSTED_ANTHROPIC_ENABLED': 'False',
    'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
    'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
    'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
    'HOSTED_MODERATION_ENABLED': 'False',
    'HOSTED_MODERATION_PROVIDERS': '',
    'CLEAN_DAY_SETTING': 30,
@@ -61,7 +54,8 @@ DEFAULTS = {
    'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
    'OUTPUT_MODERATION_BUFFER_SIZE': 300,
    'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
    'INVITE_EXPIRY_HOURS': 72
    'INVITE_EXPIRY_HOURS': 72,
    'ETL_TYPE': 'dify',
}


@@ -91,7 +85,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
        self.CURRENT_VERSION = "0.3.33"
        self.CURRENT_VERSION = "0.3.34"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -268,8 +262,6 @@ class Config:
        self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
        self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
        self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
        self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
        self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))

        self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
        self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -281,14 +273,13 @@ class Config:
        self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
        self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
        self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
        self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
        self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))

        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')

        self.ETL_TYPE = get_env('ETL_TYPE')
        self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')


class CloudEditionConfig(Config):

@@ -302,6 +293,3 @@ class CloudEditionConfig(Config):
        self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
        self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
        self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')

        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import extension, setup, version, apikey, admin

# Import app controllers
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation

# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
@@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio

# Import webhook controllers
from .webhook import stripe

from .billing import billing
api/controllers/console/app/annotation.py (new file, 290 lines)
@@ -0,0 +1,290 @@
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, marshal
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
    annotation_hit_history_fields
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from flask import request


class AnnotationReplyActionApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def post(self, app_id, action):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        parser = reqparse.RequestParser()
        parser.add_argument('score_threshold', required=True, type=float, location='json')
        parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
        parser.add_argument('embedding_model_name', required=True, type=str, location='json')
        args = parser.parse_args()
        if action == 'enable':
            result = AppAnnotationService.enable_app_annotation(args, app_id)
        elif action == 'disable':
            result = AppAnnotationService.disable_app_annotation(app_id)
        else:
            raise ValueError('Unsupported annotation reply action')
        return result, 200


class AppAnnotationSettingDetailApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
        return result, 200


class AppAnnotationSettingUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, app_id, annotation_setting_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        annotation_setting_id = str(annotation_setting_id)

        parser = reqparse.RequestParser()
        parser.add_argument('score_threshold', required=True, type=float, location='json')
        args = parser.parse_args()

        result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
        return result, 200


class AnnotationReplyActionStatusApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def get(self, app_id, job_id, action):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        job_id = str(job_id)
        app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
        cache_result = redis_client.get(app_annotation_job_key)
        if cache_result is None:
            raise ValueError("The job is not exist.")

        job_status = cache_result.decode()
        error_msg = ''
        if job_status == 'error':
            app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
            error_msg = redis_client.get(app_annotation_error_key).decode()

        return {
            'job_id': job_id,
            'job_status': job_status,
            'error_msg': error_msg
        }, 200


class AnnotationListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        keyword = request.args.get('keyword', default=None, type=str)

        app_id = str(app_id)
        annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
        response = {
            'data': marshal(annotation_list, annotation_fields),
            'has_more': len(annotation_list) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200


class AnnotationExportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
        response = {
            'data': marshal(annotation_list, annotation_fields)
        }
        return response, 200


class AnnotationCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    @marshal_with(annotation_fields)
    def post(self, app_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        parser = reqparse.RequestParser()
        parser.add_argument('question', required=True, type=str, location='json')
        parser.add_argument('answer', required=True, type=str, location='json')
        args = parser.parse_args()
        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
        return annotation


class AnnotationUpdateDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    @marshal_with(annotation_fields)
    def post(self, app_id, annotation_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        annotation_id = str(annotation_id)
        parser = reqparse.RequestParser()
        parser.add_argument('question', required=True, type=str, location='json')
        parser.add_argument('answer', required=True, type=str, location='json')
        args = parser.parse_args()
        annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
        return annotation

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, app_id, annotation_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        annotation_id = str(annotation_id)
        AppAnnotationService.delete_app_annotation(app_id, annotation_id)
        return {'result': 'success'}, 200


class AnnotationBatchImportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def post(self, app_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)
        # get file from request
        file = request.files['file']
        # check file
        if 'file' not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()
        # check file type
        if not file.filename.endswith('.csv'):
            raise ValueError("Invalid file type. Only CSV files are allowed")
        return AppAnnotationService.batch_import_app_annotations(app_id, file)


class AnnotationBatchImportStatusApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    def get(self, app_id, job_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        job_id = str(job_id)
        indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is None:
            raise ValueError("The job is not exist.")
        job_status = cache_result.decode()
        error_msg = ''
        if job_status == 'error':
            indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
            error_msg = redis_client.get(indexing_error_msg_key).decode()

        return {
            'job_id': job_id,
            'job_status': job_status,
            'error_msg': error_msg
        }, 200


class AnnotationHitHistoryListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_id, annotation_id):
        # The role of the current user in the table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        app_id = str(app_id)
        annotation_id = str(annotation_id)
        annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
                                                                                               page, limit)
        response = {
            'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
            'has_more': len(annotation_hit_history_list) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response


api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
                 '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
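The routes above define the console-side annotation CRUD surface. A hedged sketch of calling two of them with the `requests` library follows; the console API root and the bearer token are assumptions about a local deployment, not values taken from this diff:

# Sketch only: base URL, auth header, and app UUID are placeholders/assumptions.
import requests

BASE = 'http://localhost:5001/console/api'  # assumed console API prefix
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # assumed auth scheme
app_id = '00000000-0000-0000-0000-000000000000'  # placeholder app UUID

# Create an annotation directly (AnnotationCreateApi: POST /apps/<app_id>/annotations).
created = requests.post(f'{BASE}/apps/{app_id}/annotations',
                        json={'question': 'What is Dify?', 'answer': 'An LLM app platform.'},
                        headers=HEADERS)
print(created.status_code, created.json())

# Page through annotations (AnnotationListApi: GET /apps/<app_id>/annotations).
listed = requests.get(f'{BASE}/apps/{app_id}/annotations',
                      params={'page': 1, 'limit': 20}, headers=HEADERS)
print(listed.json().get('total'))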
@@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    error_code = 'provider_not_support_speech_to_text'
    description = "Provider not support speech to text."
    code = 400
    code = 400


class NoFileUploadedError(BaseHTTPException):
    error_code = 'no_file_uploaded'
    description = "Please upload your file."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = 'too_many_files'
    description = "Only one file is allowed."
    code = 400
@@ -6,22 +6,23 @@ from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
    AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
    ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('annotation')
    @marshal_with(annotation_fields)
    def post(self, app_id):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        app_id = str(app_id)

        # get app info
        app = _get_app(app_id)

        parser = reqparse.RequestParser()
        parser.add_argument('message_id', required=True, type=uuid_value, location='json')
        parser.add_argument('content', type=str, location='json')
        parser.add_argument('message_id', required=False, type=uuid_value, location='json')
        parser.add_argument('question', required=True, type=str, location='json')
        parser.add_argument('answer', required=True, type=str, location='json')
        parser.add_argument('annotation_reply', required=False, type=dict, location='json')
        args = parser.parse_args()
        annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)

        message_id = str(args['message_id'])

        message = db.session.query(Message).filter(
            Message.id == message_id,
            Message.app_id == app.id
        ).first()

        if not message:
            raise NotFound("Message Not Exists.")

        annotation = message.annotation

        if annotation:
            annotation.content = args['content']
        else:
            annotation = MessageAnnotation(
                app_id=app.id,
                conversation_id=message.conversation_id,
                message_id=message.id,
                content=args['content'],
                account_id=current_user.id
            )
            db.session.add(annotation)

        db.session.commit()

        return {'result': 'success'}
        return annotation


class MessageAnnotationCountApi(Resource):
@@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
        """Modify app model config"""
        app_id = str(app_id)

        app_model = _get_app(app_id)
        app = _get_app(app_id)

        # validate config
        model_configuration = AppModelConfigService.validate_configuration(
            tenant_id=current_user.current_tenant_id,
            account=current_user,
            config=request.json,
            mode=app_model.mode
            mode=app.mode
        )

        new_app_model_config = AppModelConfig(
            app_id=app_model.id,
            app_id=app.id,
        )
        new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

        db.session.add(new_app_model_config)
        db.session.flush()

        app_model.app_model_config_id = new_app_model_config.id
        app.app_model_config_id = new_app_model_config.id
        db.session.commit()

        app_model_config_was_updated.send(
            app_model,
            app,
            app_model_config=new_app_model_config
        )
@@ -1,9 +1,6 @@
import stripe
import os

from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app, request
from flask import current_app

from controllers.console import api
from controllers.console.setup import setup_required
@@ -40,7 +37,12 @@ class Subscription(Resource):
        parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
        args = parser.parse_args()

        return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
        BillingService.is_tenant_owner(current_user)

        return BillingService.get_subscription(args['plan'],
                                               args['interval'],
                                               current_user.email,
                                               current_user.current_tenant_id)


class Invoices(Resource):
@@ -50,36 +52,10 @@ class Invoices(Resource):
    @account_initialization_required
    @only_edition_cloud
    def get(self):

        BillingService.is_tenant_owner(current_user)
        return BillingService.get_invoices(current_user.email)


class StripeBillingWebhook(Resource):

    @setup_required
    @only_edition_cloud
    def post(self):
        payload = request.data
        sig_header = request.headers.get('STRIPE_SIGNATURE')
        webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')

        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, webhook_secret
            )
        except ValueError as e:
            # Invalid payload
            return 'Invalid payload', 400
        except stripe.error.SignatureVerificationError as e:
            # Invalid signature
            return 'Invalid signature', 400

        BillingService.process_event(event)

        return 'success', 200


api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')
@@ -69,5 +69,20 @@ class FilePreviewApi(Resource):
        return {'content': text}


class FileeSupportTypApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        etl_type = current_app.config['ETL_TYPE']
        if etl_type == 'Unstructured':
            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
                                  'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
        else:
            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
        return {'allowed_extensions': allowed_extensions}


api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileeSupportTypApi, '/files/support-type')
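The new endpoint lets a client discover which upload extensions the configured ETL pipeline accepts, so the wider list appears only when ETL_TYPE is set to Unstructured. A small sketch of querying it; the base URL and auth header are assumptions about a local deployment:

# Sketch only: URL prefix and token are placeholders.
import requests

resp = requests.get('http://localhost:5001/console/api/files/support-type',
                    headers={'Authorization': 'Bearer <console-access-token>'})
print(resp.json()['allowed_extensions'])
# ETL_TYPE=dify          -> txt, markdown, md, pdf, html, htm, xlsx, docx, csv
# ETL_TYPE=Unstructured  -> additionally eml, msg, pptx, ppt, xml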
@@ -73,7 +73,7 @@ class ConversationRenameApi(InstalledAppResource):

        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=False, location='json')
        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
        parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
        args = parser.parse_args()

        try:
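Why this default change matters: reqparse hands back the default verbatim when the field is absent from the request, and a non-empty string such as 'False' is truthy in Python, so the old string default effectively switched the flag on. A minimal illustration:

# The old default was the string 'False'; any non-empty string is truthy.
assert bool('False') is True   # old behavior: the flag was accidentally on
assert bool(False) is False    # fixed: a real boolean default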
@@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
        'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
        'more_like_this': app_model_config.more_like_this_dict,
        'user_input_form': app_model_config.user_input_form_list,
        'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
@@ -66,7 +66,7 @@ class UniversalChatConversationRenameApi(UniversalChatResource):

        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=False, location='json')
        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
        parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
        args = parser.parse_args()

        try:
@@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw
    }

    @marshal_with(parameters_fields)
@@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
    }
@@ -1,61 +0,0 @@
import logging

import stripe
from flask import request, current_app
from flask_restful import Resource

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService


class StripeWebhookApi(Resource):
    @setup_required
    @only_edition_cloud
    def post(self):
        payload = request.data
        sig_header = request.headers.get('STRIPE_SIGNATURE')
        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')

        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, webhook_secret
            )
        except ValueError as e:
            # Invalid payload
            return 'Invalid payload', 400
        except stripe.error.SignatureVerificationError as e:
            # Invalid signature
            return 'Invalid signature', 400

        # Handle the checkout.session.completed event
        if event['type'] == 'checkout.session.completed':
            logging.debug(event['data']['object']['id'])
            logging.debug(event['data']['object']['amount_subtotal'])
            logging.debug(event['data']['object']['currency'])
            logging.debug(event['data']['object']['payment_intent'])
            logging.debug(event['data']['object']['payment_status'])
            logging.debug(event['data']['object']['metadata'])

            session = stripe.checkout.Session.retrieve(
                event['data']['object']['id'],
                expand=['line_items'],
            )

            logging.debug(session.line_items['data'][0]['quantity'])

            # Fulfill the purchase...
            provider_checkout_service = ProviderCheckoutService()

            try:
                provider_checkout_service.fulfill_provider_order(event, session.line_items)
            except Exception as e:
                logging.debug(str(e))
                return 'success', 200

        return 'success', 200


api.add_resource(StripeWebhookApi, '/webhook/stripe')
@@ -9,8 +9,8 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_checkout_service import ProviderCheckoutService
from services.provider_service import ProviderService
from services.billing_service import BillingService


class ModelProviderListApi(Resource):
@@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
    @login_required
    @account_initialization_required
    def get(self, provider_name: str):
        provider_service = ProviderCheckoutService()
        provider_checkout = provider_service.create_checkout(
            tenant_id=current_user.current_tenant_id,
            provider_name=provider_name,
            account=current_user
        )
        if provider_name != 'anthropic':
            raise ValueError(f'provider name {provider_name} is invalid')

        return {
            'url': provider_checkout.get_checkout_url()
        }
        data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
                                                              tenant_id=current_user.current_tenant_id,
                                                              account_id=current_user.id)
        return data


class ModelProviderFreeQuotaSubmitApi(Resource):
@@ -10,12 +10,15 @@ from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.setup import setup_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, UnsupportedFileTypeError
from libs.helper import TimestampField
from extensions.ext_database import db
from models.account import Tenant
import services
from services.account_service import TenantService
from services.workspace_service import WorkspaceService
from services.file_service import FileService

provider_fields = {
    'provider_name': fields.String,
@@ -34,6 +37,7 @@ tenant_fields = {
    'providers': fields.List(fields.Nested(provider_fields)),
    'in_trial': fields.Boolean,
    'trial_end_reason': fields.String,
    'custom_config': fields.Raw(attribute='custom_config'),
}

tenants_fields = {
@@ -130,6 +134,61 @@ class SwitchWorkspaceApi(Resource):
        new_tenant = db.session.query(Tenant).get(args['tenant_id'])  # Get new tenant

        return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}


class CustomConfigWorkspaceApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('workspace_custom')
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('remove_webapp_brand', type=bool, location='json')
        parser.add_argument('replace_webapp_logo', type=str, location='json')
        args = parser.parse_args()

        custom_config_dict = {
            'remove_webapp_brand': args['remove_webapp_brand'],
            'replace_webapp_logo': args['replace_webapp_logo'],
        }

        tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()

        tenant.custom_config_dict = custom_config_dict
        db.session.commit()

        return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}


class WebappLogoWorkspaceApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('workspace_custom')
    def post(self):
        # get file from request
        file = request.files['file']

        # check file
        if 'file' not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        extension = file.filename.split('.')[-1]
        if extension.lower() not in ['svg', 'png']:
            raise UnsupportedFileTypeError()

        try:
            upload_file = FileService.upload_file(file, current_user, True)

        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return { 'id': upload_file.id }, 201


api.add_resource(TenantListApi, '/workspaces')  # GET for getting all tenants
@@ -137,3 +196,5 @@ api.add_resource(WorkspaceListApi, '/all-workspaces')  # GET for getting all ten
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current')  # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info')  # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch')  # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')
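Together these two endpoints let a workspace upload a logo file and then point its custom config at the returned file id. A hedged requests sketch; the URL prefix and auth header are assumptions about a local deployment:

# Sketch only: base URL and token are placeholders.
import requests

BASE = 'http://localhost:5001/console/api'  # assumed console API prefix
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # assumed auth header

# 1) Upload the logo (only svg/png pass the extension check above); expect HTTP 201 with an id.
with open('logo.png', 'rb') as f:
    uploaded = requests.post(f'{BASE}/workspaces/custom-config/webapp-logo/upload',
                             files={'file': ('logo.png', f, 'image/png')}, headers=HEADERS)
file_id = uploaded.json()['id']

# 2) Store the custom config referencing the uploaded file.
requests.post(f'{BASE}/workspaces/custom-config',
              json={'remove_webapp_brand': True, 'replace_webapp_logo': file_id},
              headers=HEADERS)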
@@ -55,6 +55,7 @@ def cloud_edition_billing_resource_check(resource: str,
            members = billing_info['members']
            apps = billing_info['apps']
            vector_space = billing_info['vector_space']
            annotation_quota_limit = billing_info['annotation_quota_limit']

            if resource == 'members' and 0 < members['limit'] <= members['size']:
                abort(403, error_msg)
@@ -62,6 +63,10 @@ def cloud_edition_billing_resource_check(resource: str,
                abort(403, error_msg)
            elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
                abort(403, error_msg)
            elif resource == 'workspace_custom' and not billing_info['can_replace_logo']:
                abort(403, error_msg)
            elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] < annotation_quota_limit['size']:
                abort(403, error_msg)
            else:
                return view(*args, **kwargs)
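The guards all follow the same chained-comparison idiom, where a limit of 0 means unlimited. Note that the new annotation branch uses a strict < where the members, apps, and vector_space branches use <=. A small sketch of the idiom:

# Chained comparison used by the billing guards: 0 means "no limit",
# otherwise block once usage reaches (or, for the annotation branch, exceeds) the cap.
def over_quota(limit: int, size: int, strict: bool = False) -> bool:
    return 0 < limit < size if strict else 0 < limit <= size

assert over_quota(0, 999) is False               # limit 0 -> unlimited
assert over_quota(10, 10) is True                # members/apps/vector_space style (<=)
assert over_quota(10, 10, strict=True) is False  # annotation style (<)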
@@ -1,10 +1,12 @@
from flask import request, Response
from flask_restful import Resource
from werkzeug.exceptions import NotFound

import services
from controllers.files import api
from libs.exception import BaseHTTPException
from services.file_service import FileService
from services.account_service import TenantService


class ImagePreviewApi(Resource):
@@ -29,9 +31,30 @@ class ImagePreviewApi(Resource):
            raise UnsupportedFileTypeError()

        return Response(generator, mimetype=mimetype)


class WorkspaceWebappLogoApi(Resource):
    def get(self, workspace_id):
        workspace_id = str(workspace_id)

        custom_config = TenantService.get_custom_config(workspace_id)
        webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None

        if not webapp_logo_file_id:
            raise NotFound(f'webapp logo is not found')

        try:
            generator, mimetype = FileService.get_public_image_preview(
                webapp_logo_file_id,
            )
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return Response(generator, mimetype=mimetype)


api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')


class UnsupportedFileTypeError(BaseHTTPException):
@@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
        'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
        'more_like_this': app_model_config.more_like_this_dict,
        'user_input_form': app_model_config.user_input_form_list,
        'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
@@ -98,7 +98,7 @@ class ChatApi(AppApiResource):
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('user', type=str, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
        parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
        parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')

        args = parser.parse_args()
@@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
        'sensitive_word_avoidance': fields.Raw,
@@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
        'speech_to_text': app_model_config.speech_to_text_dict,
        'retriever_resource': app_model_config.retriever_resource_dict,
        'annotation_reply': app_model_config.annotation_reply_dict,
        'more_like_this': app_model_config.more_like_this_dict,
        'user_input_form': app_model_config.user_input_form_list,
        'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
@@ -1,11 +1,15 @@
# -*- coding:utf-8 -*-
import os

from flask_restful import fields, marshal_with
from flask import current_app
from werkzeug.exceptions import Forbidden

from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.model import Site
from services.billing_service import BillingService


class AppSiteApi(WebApiResource):
@@ -39,6 +43,8 @@ class AppSiteApi(WebApiResource):
        'site': fields.Nested(site_fields),
        'model_config': fields.Nested(model_config_fields, allow_null=True),
        'plan': fields.String,
        'can_replace_logo': fields.Boolean,
        'custom_config': fields.Raw(attribute='custom_config'),
    }

    @marshal_with(app_fields)
@@ -50,7 +56,14 @@ class AppSiteApi(WebApiResource):
        if not site:
            raise Forbidden()

        return AppSiteInfo(app_model.tenant, app_model, site, end_user.id)
        edition = os.environ.get('EDITION')
        can_replace_logo = False

        if edition == 'CLOUD':
            info = BillingService.get_info(app_model.tenant_id)
            can_replace_logo = info['can_replace_logo']

        return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)


api.add_resource(AppSiteApi, '/site')
@@ -59,7 +72,7 @@ api.add_resource(AppSiteApi, '/site')
class AppSiteInfo:
    """Class to store site information."""

    def __init__(self, tenant, app, site, end_user):
    def __init__(self, tenant, app, site, end_user, can_replace_logo):
        """Initialize AppSiteInfo instance."""
        self.app_id = app.id
        self.end_user_id = end_user
@@ -67,6 +80,16 @@ class AppSiteInfo:
        self.site = site
        self.model_config = None
        self.plan = tenant.plan
        self.can_replace_logo = can_replace_logo

        if can_replace_logo:
            base_url = current_app.config.get('FILES_URL')
            remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
            replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
            self.custom_config = {
                'remove_webapp_brand': remove_webapp_brand,
                'replace_webapp_logo': replace_webapp_logo,
            }

        if app.enable_site and site.prompt_public:
            app_model_config = app.app_model_config
@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.file.file_obj import FileObj
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Account, Conversation, EndUser
|
||||
from core.moderation.base import ModerationException, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
class Completion:
|
||||
@ -33,7 +38,7 @@ class Completion:
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
|
||||
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
|
||||
auto_generate_name: bool = True):
|
||||
auto_generate_name: bool = True, from_source: str = 'console'):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
@ -109,7 +114,10 @@ class Completion:
|
||||
fake_response=str(e)
|
||||
)
|
||||
return
|
||||
|
||||
# check annotation reply
|
||||
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
|
||||
if annotation_reply:
|
||||
return
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_model_config.external_data_tools_list
|
||||
if external_data_tools:
|
||||
@@ -166,17 +174,18 @@ class Completion:
         except ChunkedEncodingError as e:
             # Interrupt by LLM (like OpenAI), handle it.
             logging.warning(f'ChunkedEncodingError: {e}')
             conversation_message_task.end()
             return

     @classmethod
-    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
+    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
+                              query: str):
         if not app_model_config.sensitive_word_avoidance_dict['enabled']:
             return inputs, query

         type = app_model_config.sensitive_word_avoidance_dict['type']

-        moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
+        moderation = ModerationFactory(type, app_id, tenant_id,
+                                       app_model_config.sensitive_word_avoidance_dict['config'])
         moderation_result = moderation.moderation_for_inputs(inputs, query)

         if not moderation_result.flagged:
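Review note: a minimal, self-contained sketch of the contract moderation_for_inputs depends on. The SimpleKeywordModeration backend and ModerationResult stand-in are invented for illustration; they only mirror the flagged/not-flagged round trip the factory-produced backend must provide.

    from dataclasses import dataclass

    @dataclass
    class ModerationResult:  # stand-in for the real result type in core.moderation.base
        flagged: bool
        action: str = 'direct_output'

    class SimpleKeywordModeration:
        """Invented backend mirroring the moderation_for_inputs contract."""
        def __init__(self, keywords: list[str]):
            self._keywords = keywords

        def moderation_for_inputs(self, inputs: dict, query: str) -> ModerationResult:
            return ModerationResult(flagged=any(k in query.lower() for k in self._keywords))

    result = SimpleKeywordModeration(['forbidden']).moderation_for_inputs({}, 'a normal question')
    assert result.flagged is False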
@@ -324,6 +333,81 @@ class Completion:
         external_context = memory.load_memory_variables({})
         return external_context[memory_key]

+    @classmethod
+    def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
+                                       from_source: str) -> bool:
+        """Query stored annotations and reply directly when one matches the query closely enough."""
+        app_model_config = conversation_message_task.app_model_config
+        app = conversation_message_task.app
+        annotation_reply = app_model_config.annotation_reply_dict
+        if annotation_reply['enabled']:
+            try:
+                score_threshold = annotation_reply.get('score_threshold', 1)
+                embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
+                embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
+                # get embedding model
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=app.tenant_id,
+                    model_provider_name=embedding_provider_name,
+                    model_name=embedding_model_name
+                )
+                embeddings = CacheEmbedding(embedding_model)
+
+                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                    embedding_provider_name,
+                    embedding_model_name,
+                    'annotation'
+                )
+
+                dataset = Dataset(
+                    id=app.id,
+                    tenant_id=app.tenant_id,
+                    indexing_technique='high_quality',
+                    embedding_model_provider=embedding_provider_name,
+                    embedding_model=embedding_model_name,
+                    collection_binding_id=dataset_collection_binding.id
+                )
+
+                vector_index = VectorIndex(
+                    dataset=dataset,
+                    config=current_app.config,
+                    embeddings=embeddings,
+                    attributes=['doc_id', 'annotation_id', 'app_id']
+                )
+
+                documents = vector_index.search(
+                    conversation_message_task.query,
+                    search_type='similarity_score_threshold',
+                    search_kwargs={
+                        'k': 1,
+                        'score_threshold': score_threshold,
+                        'filter': {
+                            'group_id': [dataset.id]
+                        }
+                    }
+                )
+                if documents:
+                    annotation_id = documents[0].metadata['annotation_id']
+                    score = documents[0].metadata['score']
+                    annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
+                    if annotation:
+                        conversation_message_task.annotation_end(annotation.content, annotation.id,
+                                                                 annotation.account.name)
+                        # insert annotation history
+                        AppAnnotationService.add_annotation_history(annotation.id,
+                                                                    app.id,
+                                                                    annotation.question,
+                                                                    annotation.content,
+                                                                    conversation_message_task.query,
+                                                                    conversation_message_task.user.id,
+                                                                    conversation_message_task.message.id,
+                                                                    from_source,
+                                                                    score)
+                        return True
+            except Exception as e:
+                logging.warning(f'Query annotation failed, exception: {str(e)}.')
+                return False
+        return False
+
     @classmethod
     def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                      conversation: Conversation,
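Review note: a self-contained sketch of the retrieval contract the annotation path above relies on — embed the user query, take the top-1 stored question above a score threshold (k=1 similarity_score_threshold search), and answer with its saved reply. The cosine helper and the 0.8 threshold are illustrative, not values from this diff.

    from typing import Optional

    def match_annotation(query_vec: list[float],
                         annotations: list[tuple[list[float], str]],
                         score_threshold: float = 0.8) -> Optional[str]:
        """Return the stored answer whose question embedding is closest to the
        query, if the similarity clears the threshold."""
        def cosine(a: list[float], b: list[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            na = sum(x * x for x in a) ** 0.5
            nb = sum(x * x for x in b) ** 0.5
            return dot / (na * nb) if na and nb else 0.0

        best = max(annotations, key=lambda item: cosine(query_vec, item[0]), default=None)
        if best and cosine(query_vec, best[0]) >= score_threshold:
            return best[1]
        return None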
@@ -319,6 +319,10 @@ class ConversationMessageTask:
         self._pub_handler.pub_message_end(self.retriever_resource)
         self._pub_handler.pub_end()

+    def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
+        self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
+        self._pub_handler.pub_end()
+

 class PubHandler:
     def __init__(self, user: Union[Account, EndUser], task_id: str,
@@ -435,7 +439,7 @@ class PubHandler:
                 'task_id': self._task_id,
                 'message_id': self._message.id,
                 'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id
+                'conversation_id': self._conversation.id,
             }
         }
         if retriever_resource:
@@ -446,6 +450,30 @@ class PubHandler:
             self.pub_end()
             raise ConversationTaskStoppedException()

+    def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
+        content = {
+            'event': 'annotation',
+            'data': {
+                'task_id': self._task_id,
+                'message_id': self._message.id,
+                'mode': self._conversation.mode,
+                'conversation_id': self._conversation.id,
+                'text': text,
+                'annotation_id': annotation_id,
+                'annotation_author_name': annotation_author_name
+            }
+        }
+        self._message.answer = text
+        self._message.provider_response_latency = time.perf_counter() - start_at
+
+        db.session.commit()
+
+        redis_client.publish(self._channel, json.dumps(content))
+
+        if self._is_stopped():
+            self.pub_end()
+            raise ConversationTaskStoppedException()
+
     def pub_end(self):
         content = {
             'event': 'end',
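Review note: a hedged sketch of what a subscriber on the Redis channel sees when an annotation reply fires. Only the payload shape comes from pub_annotation above; the channel name is a guess for illustration, and unlike token streaming the full annotated answer arrives in a single event.

    import json
    import redis

    r = redis.Redis()
    pubsub = r.pubsub()
    pubsub.subscribe('generate_result:example-task-id')  # hypothetical channel name

    for message in pubsub.listen():
        if message['type'] != 'message':
            continue
        event = json.loads(message['data'])
        if event['event'] == 'annotation':
            # whole answer at once, with attribution to the annotation author
            print(event['data']['text'], event['data']['annotation_author_name'])
        elif event['event'] == 'end':
            break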
@@ -3,7 +3,8 @@ from pathlib import Path
 from typing import List, Union, Optional

 import requests
-from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
+from flask import current_app
+from langchain.document_loaders import TextLoader, Docx2txtLoader
 from langchain.schema import Document

 from core.data_loader.loader.csv_loader import CSVLoader
@@ -11,6 +12,13 @@ from core.data_loader.loader.excel import ExcelLoader
 from core.data_loader.loader.html import HTMLLoader
 from core.data_loader.loader.markdown import MarkdownLoader
 from core.data_loader.loader.pdf import PdfLoader
+from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
+from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
+from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
+from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
+from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
+from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
+from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
 from extensions.ext_storage import storage
 from models.model import UploadFile

@@ -49,14 +57,34 @@ class FileExtractor:
         input_file = Path(file_path)
         delimiter = '\n'
         file_extension = input_file.suffix.lower()
-        if is_automatic:
-            loader = UnstructuredFileLoader(
-                file_path, strategy="hi_res", mode="elements"
-            )
-            # loader = UnstructuredAPIFileLoader(
-            #     file_path=filenames[0],
-            #     api_key="FAKE_API_KEY",
-            # )
+        etl_type = current_app.config['ETL_TYPE']
+        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
+        if etl_type == 'Unstructured':
+            if file_extension == '.xlsx':
+                loader = ExcelLoader(file_path)
+            elif file_extension == '.pdf':
+                loader = PdfLoader(file_path, upload_file=upload_file)
+            elif file_extension in ['.md', '.markdown']:
+                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
+            elif file_extension in ['.htm', '.html']:
+                loader = HTMLLoader(file_path)
+            elif file_extension == '.docx':
+                loader = Docx2txtLoader(file_path)
+            elif file_extension == '.csv':
+                loader = CSVLoader(file_path, autodetect_encoding=True)
+            elif file_extension == '.msg':
+                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
+            elif file_extension == '.eml':
+                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
+            elif file_extension == '.ppt':
+                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
+            elif file_extension == '.pptx':
+                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
+            elif file_extension == '.xml':
+                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
+            else:
+                # txt
+                loader = UnstructuredTextLoader(file_path, unstructured_api_url)
         else:
             if file_extension == '.xlsx':
                 loader = ExcelLoader(file_path)
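Review note: with ETL_TYPE set to 'Unstructured', extraction is now routed per file extension instead of through a single UnstructuredFileLoader, and unknown extensions fall through to plain-text partitioning. A self-contained sketch of just that dispatch decision (loader classes replaced by names, mirroring the elif chain above):

    from pathlib import Path

    UNSTRUCTURED_ROUTES = {
        '.xlsx': 'ExcelLoader', '.pdf': 'PdfLoader',
        '.md': 'UnstructuredMarkdownLoader', '.markdown': 'UnstructuredMarkdownLoader',
        '.htm': 'HTMLLoader', '.html': 'HTMLLoader',
        '.docx': 'Docx2txtLoader', '.csv': 'CSVLoader',
        '.msg': 'UnstructuredMsgLoader', '.eml': 'UnstructuredEmailLoader',
        '.ppt': 'UnstructuredPPTLoader', '.pptx': 'UnstructuredPPTXLoader',
        '.xml': 'UnstructuredXmlLoader',
    }

    def route(file_path: str) -> str:
        ext = Path(file_path).suffix.lower()
        # anything unrecognized falls through to plain-text partitioning
        return UNSTRUCTURED_ROUTES.get(ext, 'UnstructuredTextLoader')

    assert route('slides.PPTX') == 'UnstructuredPPTXLoader'
    assert route('notes.rst') == 'UnstructuredTextLoader'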
api/core/data_loader/loader/unstructured/unstructured_eml.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredEmailLoader(BaseLoader):
    """Load eml files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url


    def load(self) -> List[Document]:
        from unstructured.partition.email import partition_email

        elements = partition_email(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_markdown.py (new file)
@@ -0,0 +1,48 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredMarkdownLoader(BaseLoader):
    """Load md files.


    Args:
        file_path: Path to the file to load.

        api_url: URL of the Unstructured API used to partition the file.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.md import partition_md

        elements = partition_md(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_msg.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredMsgLoader(BaseLoader):
    """Load msg files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.msg import partition_msg

        elements = partition_msg(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_ppt.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTLoader(BaseLoader):
    """Load ppt files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.ppt import partition_ppt

        elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_pptx.py (new file)
@@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTXLoader(BaseLoader):
    """Load pptx files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.pptx import partition_pptx

        elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_text.py (new file)
@@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredTextLoader(BaseLoader):
    """Load text files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.text import partition_text

        elements = partition_text(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_xml.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredXmlLoader(BaseLoader):
    """Load xml files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.xml import partition_xml

        elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
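Review note: all seven loaders share one shape — partition into elements, chunk_by_title at 2,000 characters, wrap each chunk in a Document. A hedged usage sketch, assuming the unstructured package is installed; the file path and API URL are placeholders:

    from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader

    loader = UnstructuredPPTLoader('/tmp/slides.ppt', 'http://localhost:8000')
    for doc in loader.load():
        # each Document holds one title-delimited chunk of at most 2000 characters
        print(len(doc.page_content))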
@@ -10,7 +10,7 @@ from flask import current_app

 from extensions.ext_storage import storage

-SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
+SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']


 class UploadFileParser:
@@ -32,6 +32,10 @@ class BaseIndex(ABC):
     def delete_by_ids(self, ids: list[str]) -> None:
         raise NotImplementedError

+    @abstractmethod
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def delete_by_group_id(self, group_id: str) -> None:
         raise NotImplementedError

@@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):

         self._save_dataset_keyword_table(keyword_table)

+    def delete_by_metadata_field(self, key: str, value: str):
+        pass
+
     def get_retriever(self, **kwargs: Any) -> BaseRetriever:
         return KeywordTableRetriever(index=self, **kwargs)
@@ -100,7 +100,6 @@ class MilvusVectorIndex(BaseVectorIndex):
         """Only for created index."""
         if self._vector_store:
             return self._vector_store
-        attributes = ['doc_id', 'dataset_id', 'document_id']

         return MilvusVectorStore(
             collection_name=self.get_index_name(self.dataset),
@@ -121,6 +120,16 @@ class MilvusVectorIndex(BaseVectorIndex):
             'filter': f'id in {ids}'
         })

+    def delete_by_metadata_field(self, key: str, value: str):
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        ids = vector_store.get_ids_by_metadata_field(key, value)
+        if ids:
+            vector_store.del_texts({
+                'filter': f'id in {ids}'
+            })
+
     def delete_by_ids(self, doc_ids: list[str]) -> None:

         vector_store = self._get_vector_store()

@@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))

+    def delete_by_metadata_field(self, key: str, value: str):
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key=f"metadata.{key}",
+                    match=models.MatchValue(value=value),
+                ),
+            ],
+        ))
+
     def delete_by_ids(self, ids: list[str]) -> None:

         vector_store = self._get_vector_store()
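Review note: the Qdrant variant above filters on payload keys nested under "metadata". A hedged standalone sketch of the same delete expressed directly against qdrant-client; the collection name, URL, and app id are placeholders:

    from qdrant_client import QdrantClient
    from qdrant_client.http import models

    client = QdrantClient(url="http://localhost:6333")
    client.delete(
        collection_name="annotation_example",  # assumed collection name
        points_selector=models.FilterSelector(
            filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="metadata.app_id",
                        match=models.MatchValue(value="example-app-id"),
                    )
                ]
            )
        ),
    )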
@@ -9,12 +9,17 @@ from models.dataset import Dataset, Document


 class VectorIndex:
-    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
+    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
+                 attributes: list = None):
+        if attributes is None:
+            attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
         self._dataset = dataset
         self._embeddings = embeddings
-        self._vector_index = self._init_vector_index(dataset, config, embeddings)
+        self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
+        self._attributes = attributes

-    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
+    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
+                           attributes: list) -> BaseVectorIndex:
         vector_type = config.get('VECTOR_STORE')

         if self._dataset.index_struct_dict:
@@ -33,7 +38,8 @@ class VectorIndex:
                     api_key=config.get('WEAVIATE_API_KEY'),
                     batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                 ),
-                embeddings=embeddings
+                embeddings=embeddings,
+                attributes=attributes
             )
         elif vector_type == "qdrant":
             from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -27,9 +27,10 @@ class WeaviateConfig(BaseModel):

 class WeaviateVectorIndex(BaseVectorIndex):

-    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
+    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
         super().__init__(dataset, embeddings)
         self._client = self._init_client(config)
+        self._attributes = attributes

     def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
         auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
@@ -111,7 +112,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
         if self._vector_store:
             return self._vector_store

-        attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+        attributes = self._attributes
         if self._is_origin():
             attributes = ['doc_id']

@@ -141,6 +142,27 @@ class WeaviateVectorIndex(BaseVectorIndex):
             "valueText": document_id
         })

+    def delete_by_metadata_field(self, key: str, value: str):
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.del_texts({
+            "operator": "Equal",
+            "path": [key],
+            "valueText": value
+        })
+
+    def delete_by_group_id(self, group_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.delete()
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
             class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
@@ -397,7 +397,7 @@ class IndexingRunner:
             one_or_none()

         if file_detail:
-            text_docs = FileExtractor.load(file_detail, is_automatic=False)
+            text_docs = FileExtractor.load(file_detail, is_automatic=True)
         elif dataset_document.data_source_type == 'notion_import':
             loader = NotionLoader.from_document(dataset_document)
             text_docs = loader.load()
@@ -632,8 +632,8 @@ class IndexingRunner:
         return text

     def format_split_text(self, text):
-        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
-        matches = re.findall(regex, text, re.MULTILINE)
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
+        matches = re.findall(regex, text, re.UNICODE)

         return [
             {
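Review note: the lookahead change matters because the old (?=Q|$) stopped an answer at any capital "Q" in its body, while (?=Q\d+:|$) only stops at the next numbered question. A quick runnable check (the sample text is invented):

    import re

    text = "Q1: What is Dify?\nA1: An LLMOps platform. Quality matters.\nQ2: Is it open source?\nA2: Yes."

    old = re.findall(r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)", text, re.MULTILINE)
    new = re.findall(r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)", text, re.UNICODE)

    print(old[0][1])  # 'An LLMOps platform. ' -- truncated at the word "Quality"
    print(new[0][1])  # 'An LLMOps platform. Quality matters.\n' -- full answer kept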
@@ -23,7 +23,8 @@ FUNCTION_CALL_MODELS = [
     'gpt-4',
     'gpt-4-32k',
     'gpt-35-turbo',
-    'gpt-35-turbo-16k'
+    'gpt-35-turbo-16k',
+    'gpt-4-1106-preview'
 ]


 class AzureOpenAIModel(BaseLLM):
@@ -24,7 +24,10 @@ class CohereReranking(BaseReranking):

         super().__init__(model_provider, client, name)

-    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
+    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> \
+            Optional[List[Document]]:
+        if not documents:
+            return []
         docs = []
         doc_id = []
         unique_documents = []
@@ -34,7 +37,7 @@ class CohereReranking(BaseReranking):
                 docs.append(document.page_content)
                 unique_documents.append(document)
         documents = unique_documents


         results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
         rerank_documents = []


@@ -21,6 +21,8 @@ class XinferenceReranking(BaseReranking):
         super().__init__(model_provider, client, name)

     def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
+        if not documents:
+            return []
         docs = []
         doc_id = []
         unique_documents = []
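Review note: both rerankers deduplicate by doc_id before calling the provider, so top_k slots are not spent on repeated chunks. A self-contained sketch of that dedup step (the documents themselves are invented):

    from langchain.schema import Document

    documents = [
        Document(page_content='a', metadata={'doc_id': '1'}),
        Document(page_content='a', metadata={'doc_id': '1'}),  # duplicate chunk
        Document(page_content='b', metadata={'doc_id': '2'}),
    ]

    doc_id, unique_documents, docs = [], [], []
    for document in documents:
        if document.metadata['doc_id'] not in doc_id:
            doc_id.append(document.metadata['doc_id'])
            docs.append(document.page_content)
            unique_documents.append(document)

    assert [d.page_content for d in unique_documents] == ['a', 'b']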
@@ -191,23 +191,6 @@ class AnthropicProvider(BaseModelProvider):

         return False

-    def get_payment_info(self) -> Optional[dict]:
-        """
-        get product info if it payable.
-
-        :return:
-        """
-        if hosted_model_providers.anthropic \
-                and hosted_model_providers.anthropic.paid_enabled:
-            return {
-                'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
-                'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
-                'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
-                'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
-            }
-
-        return None
-
     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
         """
@@ -122,6 +122,22 @@ class AzureOpenAIProvider(BaseModelProvider):
                 ModelFeature.AGENT_THOUGHT.value
             ]
         },
+        {
+            'id': 'gpt-4-1106-preview',
+            'name': 'gpt-4-1106-preview',
+            'mode': ModelMode.CHAT.value,
+            'features': [
+                ModelFeature.AGENT_THOUGHT.value
+            ]
+        },
+        {
+            'id': 'gpt-4-vision-preview',
+            'name': 'gpt-4-vision-preview',
+            'mode': ModelMode.CHAT.value,
+            'features': [
+                ModelFeature.VISION.value
+            ]
+        },
         {
             'id': 'text-davinci-003',
             'name': 'text-davinci-003',
@@ -171,6 +187,8 @@ class AzureOpenAIProvider(BaseModelProvider):
         base_model_max_tokens = {
             'gpt-4': 8192,
             'gpt-4-32k': 32768,
+            'gpt-4-1106-preview': 4096,
+            'gpt-4-vision-preview': 4096,
             'gpt-35-turbo': 4096,
             'gpt-35-turbo-16k': 16384,
             'text-davinci-003': 4097,
@@ -376,6 +394,18 @@ class AzureOpenAIProvider(BaseModelProvider):
             provider_credentials=credentials
         )

+        self._add_provider_model(
+            model_name='gpt-4-1106-preview',
+            model_type=ModelType.TEXT_GENERATION,
+            provider_credentials=credentials
+        )
+
+        self._add_provider_model(
+            model_name='gpt-4-vision-preview',
+            model_type=ModelType.TEXT_GENERATION,
+            provider_credentials=credentials
+        )
+
         self._add_provider_model(
             model_name='text-davinci-003',
             model_type=ModelType.TEXT_GENERATION,
@@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC):
         ).update({'last_used': datetime.utcnow()})
         db.session.commit()

-    def get_payment_info(self) -> Optional[dict]:
-        """
-        get product info if it payable.
-
-        :return:
-        """
-        return None
-
     def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
         """
         get provider model.
@@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel):
     quota_limit: int = 0
     """Quota limit for the openai hosted model. -1 means unlimited."""
-    paid_enabled: bool = False
-    paid_stripe_price_id: str = None
-    paid_increase_quota: int = 1


 class HostedAzureOpenAI(BaseModel):
@@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel):
     quota_limit: int = 0
     """Quota limit for the anthropic hosted model. -1 means unlimited."""
-    paid_enabled: bool = False
-    paid_stripe_price_id: str = None
-    paid_increase_quota: int = 1000000
-    paid_min_quantity: int = 20
-    paid_max_quantity: int = 100


 class HostedModelProviders(BaseModel):
@@ -68,8 +62,6 @@ def init_app(app: Flask):
             api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
             quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
-            paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
-            paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
-            paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
         )

     if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
@@ -85,10 +77,6 @@ def init_app(app: Flask):
             api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
             quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
-            paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
-            paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
-            paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
-            paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
-            paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
         )

     if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
@@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider):

         return False

-    def get_payment_info(self) -> Optional[dict]:
-        """
-        get payment info if it payable.
-
-        :return:
-        """
-        if hosted_model_providers.openai \
-                and hosted_model_providers.openai.paid_enabled:
-            return {
-                'product_id': hosted_model_providers.openai.paid_stripe_price_id,
-                'increase_quota': hosted_model_providers.openai.paid_increase_quota,
-            }
-
-        return None
-
     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
         """
@@ -21,6 +21,18 @@
         "unit": "0.001",
         "currency": "USD"
     },
+    "gpt-4-1106-preview": {
+        "prompt": "0.01",
+        "completion": "0.03",
+        "unit": "0.001",
+        "currency": "USD"
+    },
+    "gpt-4-vision-preview": {
+        "prompt": "0.01",
+        "completion": "0.03",
+        "unit": "0.001",
+        "currency": "USD"
+    },
     "gpt-35-turbo": {
         "prompt": "0.002",
         "completion": "0.0015",
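Review note: in this pricing table, "unit": "0.001" means the prompt/completion rates are per 1,000 tokens. A quick arithmetic check for the gpt-4-1106-preview entry:

    from decimal import Decimal

    def request_cost(prompt_tokens: int, completion_tokens: int) -> Decimal:
        # rates for gpt-4-1106-preview from the table above
        unit = Decimal("0.001")
        prompt_rate = Decimal("0.01")
        completion_rate = Decimal("0.03")
        return (prompt_tokens * unit * prompt_rate) + (completion_tokens * unit * completion_rate)

    print(request_cost(1500, 500))  # 1500*0.00001 + 500*0.00003 = 0.030 USD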
@@ -1,11 +1,13 @@
-from typing import Dict, Any, Optional, List, Tuple, Union
+from typing import Dict, Any, Optional, List, Tuple, Union, cast

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chat_models import AzureChatOpenAI
-from langchain.chat_models.openai import _convert_dict_to_message
-from langchain.schema import ChatResult, BaseMessage, ChatGeneration
 from pydantic import root_validator

+from langchain.schema import ChatResult, BaseMessage, ChatGeneration, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
+from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile


 class EnhanceAzureChatOpenAI(AzureChatOpenAI):
     request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@@ -51,13 +53,18 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
         }

     def _generate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
-        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = self._client_params
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        message_dicts = [self._convert_message_to_dict(m) for m in messages]
         params = {**params, **kwargs}
         if self.streaming:
             inner_completion = ""
@@ -65,7 +72,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
             params["stream"] = True
             function_call: Optional[dict] = None
             for stream_resp in self.completion_with_retry(
-                messages=message_dicts, **params
+                    messages=message_dicts, **params
             ):
                 if len(stream_resp["choices"]) > 0:
                     role = stream_resp["choices"][0]["delta"].get("role", role)
@@ -88,4 +95,47 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
             )
             return ChatResult(generations=[ChatGeneration(message=message)])
         response = self.completion_with_retry(messages=message_dicts, **params)
         return self._create_chat_result(response)
+
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, LCHumanMessageWithFiles):
+            content = [
+                {
+                    "type": "text",
+                    "text": message.content
+                }
+            ]
+
+            for file in message.files:
+                if file.type == PromptMessageFileType.IMAGE:
+                    file = cast(ImagePromptMessageFile, file)
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": file.data,
+                            "detail": file.detail.value
+                        }
+                    })
+
+            message_dict = {"role": "user", "content": content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+            if "function_call" in message.additional_kwargs:
+                message_dict["function_call"] = message.additional_kwargs["function_call"]
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, FunctionMessage):
+            message_dict = {
+                "role": "function",
+                "content": message.content,
+                "name": message.name,
+            }
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        if "name" in message.additional_kwargs:
+            message_dict["name"] = message.additional_kwargs["name"]
+        return message_dict
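Review note: for an LCHumanMessageWithFiles carrying an image, _convert_message_to_dict emits OpenAI's multimodal content-parts shape, with the image passed as a (possibly base64 data:) URL. A hedged illustration of the resulting dict; all values are placeholders:

    vision_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this screenshot?"},
            {
                "type": "image_url",
                "image_url": {
                    "url": "data:image/png;base64,iVBORw0KGgo...",  # truncated for illustration
                    "detail": "high",  # ImagePromptMessageFile.detail.value
                },
            },
        ],
    }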
@@ -82,7 +82,8 @@ class DatasetMultiRetrieverTool(BaseTool):
                 hit_callback.on_tool_end(all_documents)
                 document_score_list = {}
                 for item in all_documents:
-                    document_score_list[item.metadata['doc_id']] = item.metadata['score']
+                    if 'score' in item.metadata and item.metadata['score']:
+                        document_score_list[item.metadata['doc_id']] = item.metadata['score']

                 document_context_list = []
                 index_node_ids = [document.metadata['doc_id'] for document in all_documents]

@@ -158,7 +158,8 @@ class DatasetRetrieverTool(BaseTool):
             document_score_list = {}
             if dataset.indexing_technique != "economy":
                 for item in documents:
-                    document_score_list[item.metadata['doc_id']] = item.metadata['score']
+                    if 'score' in item.metadata and item.metadata['score']:
+                        document_score_list[item.metadata['doc_id']] = item.metadata['score']
             document_context_list = []
             index_node_ids = [document.metadata['doc_id'] for document in documents]
             segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
@@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
         else:
             return None

+    def get_ids_by_metadata_field(self, key: str, value: str):
+        result = self.col.query(
+            expr=f'metadata["{key}"] == "{value}"',
+            output_fields=["id"]
+        )
+        if result:
+            return [item["id"] for item in result]
+        else:
+            return None
+
     def get_ids_by_doc_ids(self, doc_ids: list):
         result = self.col.query(
             expr=f'metadata["doc_id"] in {doc_ids}',
@@ -243,7 +243,7 @@ class Weaviate(VectorStore):
             query_obj = query_obj.with_where(kwargs.get("where_filter"))
         if kwargs.get("additional"):
             query_obj = query_obj.with_additional(kwargs.get("additional"))
-        properties = ['text', 'dataset_id', 'doc_hash', 'doc_id', 'document_id']
+        properties = ['text']
         result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
         if "errors" in result:
             raise ValueError(f"Error during query: {result['errors']}")
@@ -6,13 +6,13 @@ from models.model import AppModelConfig

 @app_model_config_was_updated.connect
 def handle(sender, **kwargs):
-    app_model = sender
+    app = sender
     app_model_config = kwargs.get('app_model_config')

     dataset_ids = get_dataset_ids_from_model_config(app_model_config)

     app_dataset_joins = db.session.query(AppDatasetJoin).filter(
-        AppDatasetJoin.app_id == app_model.id
+        AppDatasetJoin.app_id == app.id
     ).all()

     removed_dataset_ids = []
@@ -29,14 +29,14 @@ def handle(sender, **kwargs):
     if removed_dataset_ids:
         for dataset_id in removed_dataset_ids:
             db.session.query(AppDatasetJoin).filter(
-                AppDatasetJoin.app_id == app_model.id,
+                AppDatasetJoin.app_id == app.id,
                 AppDatasetJoin.dataset_id == dataset_id
             ).delete()

     if added_dataset_ids:
         for dataset_id in added_dataset_ids:
             app_dataset_join = AppDatasetJoin(
-                app_id=app_model.id,
+                app_id=app.id,
                 dataset_id=dataset_id
             )
             db.session.add(app_dataset_join)
@@ -1,6 +0,0 @@
-import stripe
-
-
-def init_app(app):
-    if app.config.get('STRIPE_API_KEY'):
-        stripe.api_key = app.config.get('STRIPE_API_KEY')
api/fields/annotation_fields.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from flask_restful import fields
from libs.helper import TimestampField

account_fields = {
    'id': fields.String,
    'name': fields.String,
    'email': fields.String
}


annotation_fields = {
    "id": fields.String,
    "question": fields.String,
    "answer": fields.Raw(attribute='content'),
    "hit_count": fields.Integer,
    "created_at": TimestampField,
    # 'account': fields.Nested(account_fields, allow_null=True)
}

annotation_list_fields = {
    "data": fields.List(fields.Nested(annotation_fields)),
}

annotation_hit_history_fields = {
    "id": fields.String,
    "source": fields.String,
    "score": fields.Float,
    "question": fields.String,
    "created_at": TimestampField,
    "match": fields.String(attribute='annotation_question'),
    "response": fields.String(attribute='annotation_content')
}

annotation_hit_history_list_fields = {
    "data": fields.List(fields.Nested(annotation_hit_history_fields)),
}
@@ -21,6 +21,7 @@ model_config_fields = {
     'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
     'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
     'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
+    'annotation_reply': fields.Raw(attribute='annotation_reply_dict'),
     'more_like_this': fields.Raw(attribute='more_like_this_dict'),
     'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
     'external_data_tools': fields.Raw(attribute='external_data_tools_list'),
@@ -23,11 +23,18 @@ feedback_fields = {
 }

 annotation_fields = {
     'id': fields.String,
+    'question': fields.String,
     'content': fields.String,
     'account': fields.Nested(account_fields, allow_null=True),
     'created_at': TimestampField
 }

+annotation_hit_history_fields = {
+    'annotation_id': fields.String,
+    'annotation_create_account': fields.Nested(account_fields, allow_null=True)
+}
+
 message_file_fields = {
     'id': fields.String,
     'type': fields.String,
@@ -49,6 +56,7 @@ message_detail_fields = {
     'from_account_id': fields.String,
     'feedbacks': fields.List(fields.Nested(feedback_fields)),
     'annotation': fields.Nested(annotation_fields, allow_null=True),
+    'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True),
     'created_at': TimestampField,
     'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
 }
@@ -0,0 +1,50 @@
"""add_app_anntation_setting

Revision ID: 246ba09cbbdb
Revises: 714aafe25d39
Create Date: 2023-12-14 11:26:12.287264

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '246ba09cbbdb'
down_revision = '714aafe25d39'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('app_annotation_settings',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('app_id', postgresql.UUID(), nullable=False),
        sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False),
        sa.Column('collection_binding_id', postgresql.UUID(), nullable=False),
        sa.Column('created_user_id', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_user_id', postgresql.UUID(), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey')
    )
    with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
        batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False)

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column('annotation_reply')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))

    with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
        batch_op.drop_index('app_annotation_settings_app_idx')

    op.drop_table('app_annotation_settings')
    # ### end Alembic commands ###
@@ -0,0 +1,32 @@
"""add-annotation-histoiry-score

Revision ID: 46976cc39132
Revises: e1901f623fd0
Create Date: 2023-12-13 04:39:59.302971

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '46976cc39132'
down_revision = 'e1901f623fd0'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.drop_column('score')

    # ### end Alembic commands ###
@@ -0,0 +1,34 @@
"""add_anntation_history_match_response

Revision ID: 714aafe25d39
Revises: f2a6fc85e260
Create Date: 2023-12-14 06:38:02.972527

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '714aafe25d39'
down_revision = 'f2a6fc85e260'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
        batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.drop_column('annotation_content')
        batch_op.drop_column('annotation_question')

    # ### end Alembic commands ###
@@ -0,0 +1,32 @@
"""add custom config in tenant

Revision ID: 88072f0caa04
Revises: 246ba09cbbdb
Create Date: 2023-12-14 07:36:50.705362

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '88072f0caa04'
down_revision = '246ba09cbbdb'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tenants', schema=None) as batch_op:
        batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tenants', schema=None) as batch_op:
        batch_op.drop_column('custom_config')

    # ### end Alembic commands ###
api/migrations/versions/e1901f623fd0_add_annotation_reply.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""add-annotation-reply

Revision ID: e1901f623fd0
Revises: fca025d3b60f
Create Date: 2023-12-12 06:58:41.054544

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'e1901f623fd0'
down_revision = 'fca025d3b60f'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('app_annotation_hit_histories',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('app_id', postgresql.UUID(), nullable=False),
        sa.Column('annotation_id', postgresql.UUID(), nullable=False),
        sa.Column('source', sa.Text(), nullable=False),
        sa.Column('question', sa.Text(), nullable=False),
        sa.Column('account_id', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey')
    )
    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.create_index('app_annotation_hit_histories_account_idx', ['account_id'], unique=False)
        batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
        batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))

    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
        batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'::character varying"), nullable=False))

    with op.batch_alter_table('message_annotations', schema=None) as batch_op:
        batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
        batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
        batch_op.alter_column('conversation_id',
                              existing_type=postgresql.UUID(),
                              nullable=True)
        batch_op.alter_column('message_id',
                              existing_type=postgresql.UUID(),
                              nullable=True)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('message_annotations', schema=None) as batch_op:
        batch_op.alter_column('message_id',
                              existing_type=postgresql.UUID(),
                              nullable=False)
        batch_op.alter_column('conversation_id',
                              existing_type=postgresql.UUID(),
                              nullable=False)
        batch_op.drop_column('hit_count')
        batch_op.drop_column('question')

    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
        batch_op.drop_column('type')

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column('annotation_reply')

    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.drop_index('app_annotation_hit_histories_app_idx')
        batch_op.drop_index('app_annotation_hit_histories_annotation_idx')
        batch_op.drop_index('app_annotation_hit_histories_account_idx')

    op.drop_table('app_annotation_hit_histories')
    # ### end Alembic commands ###
@@ -0,0 +1,34 @@
"""add_anntation_history_message_id

Revision ID: f2a6fc85e260
Revises: 46976cc39132
Create Date: 2023-12-13 11:09:29.329584

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'f2a6fc85e260'
down_revision = '46976cc39132'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
        batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
        batch_op.drop_index('app_annotation_hit_histories_message_idx')
        batch_op.drop_column('message_id')

    # ### end Alembic commands ###
@@ -1,4 +1,6 @@
+import json
 import enum
+from math import e
 from typing import List

 from flask_login import UserMixin
@@ -112,6 +114,7 @@ class Tenant(db.Model):
     encrypt_public_key = db.Column(db.Text)
     plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
     status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
+    custom_config = db.Column(db.Text)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@@ -121,6 +124,14 @@ class Tenant(db.Model):
             Account.id == TenantAccountJoin.account_id,
             TenantAccountJoin.tenant_id == self.id
         ).all()

+    @property
+    def custom_config_dict(self) -> dict:
+        return json.loads(self.custom_config) if self.custom_config else {}
+
+    @custom_config_dict.setter
+    def custom_config_dict(self, value: dict):
+        self.custom_config = json.dumps(value)
+

 class TenantAccountJoinRole(enum.Enum):
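Review note: a small round-trip sketch of the new Tenant.custom_config_dict property pair — the column stores JSON text and the getter/setter hide the (de)serialization. The config key is hypothetical, assuming models.account.Tenant is importable:

    tenant = Tenant()
    tenant.custom_config_dict = {'remove_webapp_brand': True}  # hypothetical key
    assert tenant.custom_config == '{"remove_webapp_brand": true}'
    assert tenant.custom_config_dict['remove_webapp_brand'] is True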
@@ -135,7 +135,7 @@ class DatasetProcessRule(db.Model):
         ],
         'segmentation': {
             'delimiter': '\n',
-            'max_tokens': 512
+            'max_tokens': 1000
         }
     }

@@ -475,5 +475,6 @@ class DatasetCollectionBinding(db.Model):
     id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
     provider_name = db.Column(db.String(40), nullable=False)
     model_name = db.Column(db.String(40), nullable=False)
+    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -2,6 +2,7 @@ import json

 from flask import current_app, request
 from flask_login import UserMixin
+from sqlalchemy import Float
 from sqlalchemy.dialects.postgresql import UUID

 from core.file.upload_file_parser import UploadFileParser
@@ -128,6 +129,25 @@ class AppModelConfig(db.Model):
         return json.loads(self.retriever_resource) if self.retriever_resource \
             else {"enabled": False}

+    @property
+    def annotation_reply_dict(self) -> dict:
+        annotation_setting = db.session.query(AppAnnotationSetting).filter(
+            AppAnnotationSetting.app_id == self.app_id).first()
+        if annotation_setting:
+            collection_binding_detail = annotation_setting.collection_binding_detail
+            return {
+                "id": annotation_setting.id,
+                "enabled": True,
+                "score_threshold": annotation_setting.score_threshold,
+                "embedding_model": {
+                    "embedding_provider_name": collection_binding_detail.provider_name,
+                    "embedding_model_name": collection_binding_detail.model_name
+                }
+            }
+
+        else:
+            return {"enabled": False}
+
     @property
     def more_like_this_dict(self) -> dict:
         return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@@ -170,7 +190,9 @@ class AppModelConfig(db.Model):

     @property
     def file_upload_dict(self) -> dict:
-        return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
+        return json.loads(self.file_upload) if self.file_upload else {
+            "image": {"enabled": False, "number_limits": 3, "detail": "high",
+                      "transfer_methods": ["remote_url", "local_file"]}}

     def to_dict(self) -> dict:
         return {
@@ -182,6 +204,7 @@ class AppModelConfig(db.Model):
             "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
             "speech_to_text": self.speech_to_text_dict,
             "retriever_resource": self.retriever_resource_dict,
+            "annotation_reply": self.annotation_reply_dict,
             "more_like_this": self.more_like_this_dict,
             "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
             "external_data_tools": self.external_data_tools_list,
@@ -504,6 +527,12 @@ class Message(db.Model):
         annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first()
         return annotation

+    @property
+    def annotation_hit_history(self):
+        annotation_history = (db.session.query(AppAnnotationHitHistory)
+                              .filter(AppAnnotationHitHistory.message_id == self.id).first())
+        return annotation_history
+
     @property
     def app_model_config(self):
         conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
@@ -616,9 +645,11 @@ class MessageAnnotation(db.Model):

     id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
     app_id = db.Column(UUID, nullable=False)
-    conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
-    message_id = db.Column(UUID, nullable=False)
+    conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
+    message_id = db.Column(UUID, nullable=True)
+    question = db.Column(db.Text, nullable=True)
     content = db.Column(db.Text, nullable=False)
+    hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     account_id = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -629,6 +660,79 @@ class MessageAnnotation(db.Model):
         return account


+class AppAnnotationHitHistory(db.Model):
+    __tablename__ = 'app_annotation_hit_histories'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'),
+        db.Index('app_annotation_hit_histories_app_idx', 'app_id'),
+        db.Index('app_annotation_hit_histories_account_idx', 'account_id'),
+        db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'),
+        db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(UUID, nullable=False)
+    annotation_id = db.Column(UUID, nullable=False)
+    source = db.Column(db.Text, nullable=False)
+    question = db.Column(db.Text, nullable=False)
+    account_id = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    score = db.Column(Float, nullable=False, server_default=db.text('0'))
|
||||
message_id = db.Column(UUID, nullable=False)
|
||||
annotation_question = db.Column(db.Text, nullable=False)
|
||||
annotation_content = db.Column(db.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
account = (db.session.query(Account)
|
||||
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
|
||||
.filter(MessageAnnotation.id == self.annotation_id).first())
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).filter(Account.id == self.account_id).first()
|
||||
return account
|
||||
|
||||
|
||||
class AppAnnotationSetting(db.Model):
|
||||
__tablename__ = 'app_annotation_settings'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'),
|
||||
db.Index('app_annotation_settings_app_idx', 'app_id')
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
app_id = db.Column(UUID, nullable=False)
|
||||
score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
|
||||
collection_binding_id = db.Column(UUID, nullable=False)
|
||||
created_user_id = db.Column(UUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_user_id = db.Column(UUID, nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def created_account(self):
|
||||
account = (db.session.query(Account)
|
||||
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
|
||||
                   .filter(AppAnnotationSetting.id == self.id).first())
        return account

    @property
    def updated_account(self):
        account = (db.session.query(Account)
                   .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
                   .filter(AppAnnotationSetting.id == self.id).first())
        return account

    @property
    def collection_binding_detail(self):
        from .dataset import DatasetCollectionBinding
        collection_binding_detail = (db.session.query(DatasetCollectionBinding)
                                     .filter(DatasetCollectionBinding.id == self.collection_binding_id).first())
        return collection_binding_detail


class OperationLog(db.Model):
    __tablename__ = 'operation_logs'
    __table_args__ = (

@@ -135,21 +135,6 @@ class TenantPreferredModelProvider(db.Model):
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


class ProviderOrderPaymentStatus(Enum):
    WAIT_PAY = 'wait_pay'
    PAID = 'paid'
    PAY_FAILED = 'pay_failed'
    REFUNDED = 'refunded'

    @staticmethod
    def value_of(value):
        for member in ProviderOrderPaymentStatus:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class ProviderOrder(db.Model):
    __tablename__ = 'provider_orders'
    __table_args__ = (

@@ -46,7 +46,6 @@ websocket-client~=1.6.1
dashscope~=1.11.0
huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference-client~=0.6.4
safetensors==0.3.2

@@ -54,4 +53,6 @@ zhipuai==1.0.7
werkzeug==2.3.7
pymilvus==2.3.0
qdrant-client==1.6.4
cohere~=4.32
cohere~=4.32
unstructured~=0.10.27
unstructured[docx,pptx]~=0.10.27

@@ -412,6 +412,12 @@ class TenantService:
        db.session.delete(tenant)
        db.session.commit()

    @staticmethod
    def get_custom_config(tenant_id: str) -> dict:
        tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404()

        return tenant.custom_config_dict


class RegisterService:


426  api/services/annotation_service.py  Normal file
@@ -0,0 +1,426 @@
import datetime
import json
import uuid

import pandas as pd
from flask_login import current_user
from sqlalchemy import or_
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import MessageAnnotation, Message, App, AppAnnotationHitHistory, AppAnnotationSetting
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task


class AppAnnotationService:
    @classmethod
    def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")
        if 'message_id' in args and args['message_id']:
            message_id = str(args['message_id'])
            # get message info
            message = db.session.query(Message).filter(
                Message.id == message_id,
                Message.app_id == app.id
            ).first()

            if not message:
                raise NotFound("Message Not Exists.")

            annotation = message.annotation
            # save the message annotation
            if annotation:
                annotation.content = args['answer']
                annotation.question = args['question']
            else:
                annotation = MessageAnnotation(
                    app_id=app.id,
                    conversation_id=message.conversation_id,
                    message_id=message.id,
                    content=args['answer'],
                    question=args['question'],
                    account_id=current_user.id
                )
        else:
            annotation = MessageAnnotation(
                app_id=app.id,
                content=args['answer'],
                question=args['question'],
                account_id=current_user.id
            )
        db.session.add(annotation)
        db.session.commit()
        # if annotation reply is enabled, add the annotation to the index
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
                                               app_id, annotation_setting.collection_binding_id)
        return annotation

    @classmethod
    def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
        enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
        cache_result = redis_client.get(enable_app_annotation_key)
        if cache_result is not None:
            return {
                'job_id': cache_result,
                'job_status': 'processing'
            }

        # async job
        job_id = str(uuid.uuid4())
        enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
        # send batch add segments task
        redis_client.setnx(enable_app_annotation_job_key, 'waiting')
        enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
                                           args['score_threshold'],
                                           args['embedding_provider_name'], args['embedding_model_name'])
        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }

    @classmethod
    def disable_app_annotation(cls, app_id: str) -> dict:
        disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
        cache_result = redis_client.get(disable_app_annotation_key)
        if cache_result is not None:
            return {
                'job_id': cache_result,
                'job_status': 'processing'
            }

        # async job
        job_id = str(uuid.uuid4())
        disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
        # send batch add segments task
        redis_client.setnx(disable_app_annotation_job_key, 'waiting')
        disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }

    @classmethod
    def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")
        if keyword:
            annotations = (db.session.query(MessageAnnotation)
                           .filter(MessageAnnotation.app_id == app_id)
                           .filter(
                               or_(
                                   MessageAnnotation.question.ilike('%{}%'.format(keyword)),
                                   MessageAnnotation.content.ilike('%{}%'.format(keyword))
                               )
                           )
                           .order_by(MessageAnnotation.created_at.desc())
                           .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
        else:
            annotations = (db.session.query(MessageAnnotation)
                           .filter(MessageAnnotation.app_id == app_id)
                           .order_by(MessageAnnotation.created_at.desc())
                           .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
        return annotations.items, annotations.total

    @classmethod
    def export_annotation_list_by_app_id(cls, app_id: str):
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")
        annotations = (db.session.query(MessageAnnotation)
                       .filter(MessageAnnotation.app_id == app_id)
                       .order_by(MessageAnnotation.created_at.desc()).all())
        return annotations

    @classmethod
    def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = MessageAnnotation(
            app_id=app.id,
            content=args['answer'],
            question=args['question'],
            account_id=current_user.id
        )
        db.session.add(annotation)
        db.session.commit()
        # if annotation reply is enabled, add the annotation to the index
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
                                               app_id, annotation_setting.collection_binding_id)
        return annotation

    @classmethod
    def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            raise NotFound("Annotation not found")

        annotation.content = args['answer']
        annotation.question = args['question']

        db.session.commit()
        # if annotation reply is enabled, update the annotation in the index
        app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id
        ).first()

        if app_annotation_setting:
            update_annotation_to_index_task.delay(annotation.id, annotation.question,
                                                  current_user.current_tenant_id,
                                                  app_id, app_annotation_setting.collection_binding_id)

        return annotation

    @classmethod
    def delete_app_annotation(cls, app_id: str, annotation_id: str):
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            raise NotFound("Annotation not found")

        db.session.delete(annotation)

        annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
                                    .filter(AppAnnotationHitHistory.annotation_id == annotation_id)
                                    .all()
                                    )
        if annotation_hit_histories:
            for annotation_hit_history in annotation_hit_histories:
                db.session.delete(annotation_hit_history)

        db.session.commit()
        # if annotation reply is enabled, delete the annotation from the index
        app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id
        ).first()

        if app_annotation_setting:
            delete_annotation_index_task.delay(annotation.id, app_id,
                                               current_user.current_tenant_id,
                                               app_annotation_setting.collection_binding_id)

    @classmethod
    def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        try:
            # the first row is consumed by pandas as the CSV header, so it is skipped
            df = pd.read_csv(file)
            result = []
            for index, row in df.iterrows():
                content = {
                    'question': row[0],
                    'answer': row[1]
                }
                result.append(content)
            if len(result) == 0:
                raise ValueError("The CSV file is empty.")
            # async job
            job_id = str(uuid.uuid4())
            indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
            # send batch add segments task
            redis_client.setnx(indexing_cache_key, 'waiting')
            batch_import_annotations_task.delay(str(job_id), result, app_id,
                                                current_user.current_tenant_id, current_user.id)
        except Exception as e:
            return {
                'error_msg': str(e)
            }
        return {
            'job_id': job_id,
            'job_status': 'waiting'
        }

    @classmethod
    def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            raise NotFound("Annotation not found")

        annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
                                    .filter(AppAnnotationHitHistory.app_id == app_id,
                                            AppAnnotationHitHistory.annotation_id == annotation_id,
                                            )
                                    .order_by(AppAnnotationHitHistory.created_at.desc())
                                    .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
        return annotation_hit_histories.items, annotation_hit_histories.total

    @classmethod
    def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()

        if not annotation:
            return None
        return annotation

    @classmethod
    def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
                               annotation_content: str, query: str, user_id: str,
                               message_id: str, from_source: str, score: float):
        # add hit count to annotation
        db.session.query(MessageAnnotation).filter(
            MessageAnnotation.id == annotation_id
        ).update(
            {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
            synchronize_session=False
        )

        annotation_hit_history = AppAnnotationHitHistory(
            annotation_id=annotation_id,
            app_id=app_id,
            account_id=user_id,
            question=query,
            source=from_source,
            score=score,
            message_id=message_id,
            annotation_question=annotation_question,
            annotation_content=annotation_content
        )
        db.session.add(annotation_hit_history)
        db.session.commit()

    @classmethod
    def get_app_annotation_setting_by_app_id(cls, app_id: str):
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            collection_binding_detail = annotation_setting.collection_binding_detail
            return {
                "id": annotation_setting.id,
                "enabled": True,
                "score_threshold": annotation_setting.score_threshold,
                "embedding_model": {
                    "embedding_provider_name": collection_binding_detail.provider_name,
                    "embedding_model_name": collection_binding_detail.model_name
                }
            }
        return {
            "enabled": False
        }

    @classmethod
    def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
        # get app info
        app = db.session.query(App).filter(
            App.id == app_id,
            App.tenant_id == current_user.current_tenant_id,
            App.status == 'normal'
        ).first()

        if not app:
            raise NotFound("App not found")

        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id,
            AppAnnotationSetting.id == annotation_setting_id,
        ).first()
        if not annotation_setting:
            raise NotFound("App annotation not found")
        annotation_setting.score_threshold = args['score_threshold']
        annotation_setting.updated_user_id = current_user.id
        annotation_setting.updated_at = datetime.datetime.utcnow()
        db.session.add(annotation_setting)
        db.session.commit()

        collection_binding_detail = annotation_setting.collection_binding_detail

        return {
            "id": annotation_setting.id,
            "enabled": True,
            "score_threshold": annotation_setting.score_threshold,
            "embedding_model": {
                "embedding_provider_name": collection_binding_detail.provider_name,
                "embedding_model_name": collection_binding_detail.model_name
            }
        }

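Taken together, the new service follows one pattern throughout: scope the App to the caller's tenant, mutate the MessageAnnotation rows, then fan index maintenance out to Celery. A hedged usage sketch (the args keys follow the code above; it assumes a Flask request context with current_user bound and a valid app_id):

    annotation = AppAnnotationService.insert_app_annotation_directly(
        {'question': 'What is Dify?', 'answer': 'An LLM application platform.'},
        app_id,
    )
    job = AppAnnotationService.enable_app_annotation(
        {'score_threshold': 0.8,                              # example value
         'embedding_provider_name': 'openai',                 # example provider
         'embedding_model_name': 'text-embedding-ada-002'},   # example model
        app_id,
    )
    # job == {'job_id': '<uuid>', 'job_status': 'waiting'}; the Redis key
    # 'enable_app_annotation_job_<job_id>' moves to 'completed' or 'error'.
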
@@ -138,7 +138,22 @@ class AppModelConfigService:
            config["retriever_resource"]["enabled"] = False

        if not isinstance(config["retriever_resource"]["enabled"], bool):
            raise ValueError("enabled in speech_to_text must be of boolean type")
            raise ValueError("enabled in retriever_resource must be of boolean type")

        # annotation reply
        if 'annotation_reply' not in config or not config["annotation_reply"]:
            config["annotation_reply"] = {
                "enabled": False
            }

        if not isinstance(config["annotation_reply"], dict):
            raise ValueError("annotation_reply must be of dict type")

        if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
            config["annotation_reply"]["enabled"] = False

        if not isinstance(config["annotation_reply"]["enabled"], bool):
            raise ValueError("enabled in annotation_reply must be of boolean type")

        # more_like_this
        if 'more_like_this' not in config or not config["more_like_this"]:

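The validation added here only normalizes annotation_reply to a dict with a boolean enabled flag. For reference, a config that passes the checks might look like the following (the score_threshold/embedding_model keys are inferred from annotation_reply_dict earlier in this diff, not enforced by this hunk):

    config['annotation_reply'] = {
        'enabled': True,
        'score_threshold': 0.8,                                # example value
        'embedding_model': {
            'embedding_provider_name': 'openai',               # example provider
            'embedding_model_name': 'text-embedding-ada-002',  # example model
        },
    }
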
@@ -325,6 +340,7 @@ class AppModelConfigService:
            "suggested_questions_after_answer": config["suggested_questions_after_answer"],
            "speech_to_text": config["speech_to_text"],
            "retriever_resource": config["retriever_resource"],
            "annotation_reply": config["annotation_reply"],
            "more_like_this": config["more_like_this"],
            "sensitive_word_avoidance": config["sensitive_word_avoidance"],
            "external_data_tools": config["external_data_tools"],

@@ -1,6 +1,10 @@
import os

import requests

from extensions.ext_database import db
from models.account import TenantAccountJoin


class BillingService:
    base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')

@@ -10,7 +14,7 @@ class BillingService:
    def get_info(cls, tenant_id: str):
        params = {'tenant_id': tenant_id}

        billing_info = cls._send_request('GET', '/info', params=params)
        billing_info = cls._send_request('GET', '/subscription/info', params=params)

        return billing_info

@@ -18,16 +22,26 @@ class BillingService:
    def get_subscription(cls, plan: str,
                         interval: str,
                         prefilled_email: str = '',
                         user_name: str = '',
                         tenant_id: str = ''):
        params = {
            'plan': plan,
            'interval': interval,
            'prefilled_email': prefilled_email,
            'user_name': user_name,
            'tenant_id': tenant_id
        }
        return cls._send_request('GET', '/subscription', params=params)
        return cls._send_request('GET', '/subscription/payment-link', params=params)

    @classmethod
    def get_model_provider_payment_link(cls,
                                        provider_name: str,
                                        tenant_id: str,
                                        account_id: str):
        params = {
            'provider_name': provider_name,
            'tenant_id': tenant_id,
            'account_id': account_id
        }
        return cls._send_request('GET', '/model-provider/payment-link', params=params)

    @classmethod
    def get_invoices(cls, prefilled_email: str = ''):

@@ -46,9 +60,14 @@ class BillingService:

        return response.json()

    @classmethod
    def process_event(cls, event: dict):
        json = {
            "content": event,
        }
        return cls._send_request('POST', '/webhook/stripe', json=json)
    @staticmethod
    def is_tenant_owner(current_user):
        tenant_id = current_user.current_tenant_id

        join = db.session.query(TenantAccountJoin).filter(
            TenantAccountJoin.tenant_id == tenant_id,
            TenantAccountJoin.account_id == current_user.id
        ).first()

        if join.role != 'owner':
            raise ValueError('Only tenant owner can perform this action')

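BillingService is a thin client over an external billing API rooted at BILLING_API_URL. A hedged call sketch using the endpoints as renamed in this diff (the URL and plan name are examples; the response shape, e.g. can_replace_logo, is assumed from the workspace_service hunk below):

    import os
    os.environ.setdefault('BILLING_API_URL', 'https://billing.example.com')  # hypothetical URL

    info = BillingService.get_info(tenant_id)        # GET /subscription/info
    link = BillingService.get_subscription(
        plan='professional', interval='month',       # example plan/interval
        prefilled_email='owner@example.com', tenant_id=tenant_id,
    )                                                # GET /subscription/payment-link
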
@@ -165,7 +165,8 @@ class CompletionService:
            'streaming': streaming,
            'is_model_config_override': is_model_config_override,
            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
            'auto_generate_name': auto_generate_name
            'auto_generate_name': auto_generate_name,
            'from_source': from_source
        })

        generate_worker_thread.start()

@@ -193,7 +194,7 @@ class CompletionService:
                                  query: str, inputs: dict, files: List[PromptMessageFile],
                                  detached_user: Union[Account, EndUser],
                                  detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                                  retriever_from: str = 'dev', auto_generate_name: bool = True):
                                  retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
        with flask_app.app_context():
            # fix the state of the model objects after they are detached from the original session
            user = db.session.merge(detached_user)

@@ -218,7 +219,8 @@ class CompletionService:
                    streaming=streaming,
                    is_override=is_model_config_override,
                    retriever_from=retriever_from,
                    auto_generate_name=auto_generate_name
                    auto_generate_name=auto_generate_name,
                    from_source=from_source
                )
            except (ConversationTaskInterruptException, ConversationTaskStoppedException):
                pass

@@ -385,6 +387,9 @@ class CompletionService:
                result = json.loads(result)
                if result.get('error'):
                    cls.handle_error(result)
                if result['event'] == 'annotation' and 'data' in result:
                    message_result['annotation'] = result.get('data')
                    return cls.get_blocking_annotation_message_response_data(message_result)
                if result['event'] == 'message' and 'data' in result:
                    message_result['message'] = result.get('data')
                if result['event'] == 'message_end' and 'data' in result:

@@ -427,6 +432,9 @@ class CompletionService:
                    elif event == 'agent_thought':
                        yield "data: " + json.dumps(
                            cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
                    elif event == 'annotation':
                        yield "data: " + json.dumps(
                            cls.get_annotation_response_data(result.get('data'))) + "\n\n"
                    elif event == 'message_end':
                        yield "data: " + json.dumps(
                            cls.get_message_end_data(result.get('data'))) + "\n\n"

@@ -499,6 +507,25 @@ class CompletionService:

        return response_data

    @classmethod
    def get_blocking_annotation_message_response_data(cls, data: dict):
        message = data.get('annotation')
        response_data = {
            'event': 'annotation',
            'task_id': message.get('task_id'),
            'id': message.get('message_id'),
            'answer': message.get('text'),
            'metadata': {},
            'created_at': int(time.time()),
            'annotation_id': message.get('annotation_id'),
            'annotation_author_name': message.get('annotation_author_name')
        }

        if message.get('mode') == 'chat':
            response_data['conversation_id'] = message.get('conversation_id')

        return response_data

    @classmethod
    def get_message_end_data(cls, data: dict):
        response_data = {

@@ -551,6 +578,23 @@ class CompletionService:

        return response_data

    @classmethod
    def get_annotation_response_data(cls, data: dict):
        response_data = {
            'event': 'annotation',
            'task_id': data.get('task_id'),
            'id': data.get('message_id'),
            'answer': data.get('text'),
            'created_at': int(time.time()),
            'annotation_id': data.get('annotation_id'),
            'annotation_author_name': data.get('annotation_author_name'),
        }

        if data.get('mode') == 'chat':
            response_data['conversation_id'] = data.get('conversation_id')

        return response_data

    @classmethod
    def handle_error(cls, result: dict):
        logging.debug("error: %s", result)

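The streaming branch above introduces a new annotation SSE event, emitted when a cached annotation answers the query directly instead of the model. For reference, one frame yielded by the generator would look roughly like this (values illustrative; conversation_id appears only in chat mode):

    data: {"event": "annotation", "task_id": "<uuid>", "id": "<message_id>", "answer": "<annotation text>", "created_at": 1700000000, "annotation_id": "<uuid>", "annotation_author_name": "Alice", "conversation_id": "<uuid>"}
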
@@ -33,10 +33,7 @@ from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task

@@ -1175,10 +1172,12 @@ class SegmentService:

class DatasetCollectionBindingService:
    @classmethod
    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
    def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
                                       collection_type: str = 'dataset') -> DatasetCollectionBinding:
        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
            filter(DatasetCollectionBinding.provider_name == provider_name,
                   DatasetCollectionBinding.model_name == model_name). \
                   DatasetCollectionBinding.model_name == model_name,
                   DatasetCollectionBinding.type == collection_type). \
            order_by(DatasetCollectionBinding.created_at). \
            first()

@@ -1186,8 +1185,20 @@ class DatasetCollectionBindingService:
            dataset_collection_binding = DatasetCollectionBinding(
                provider_name=provider_name,
                model_name=model_name,
                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node',
                type=collection_type
            )
            db.session.add(dataset_collection_binding)
            db.session.flush()
            db.session.commit()
        return dataset_collection_binding

    @classmethod
    def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
                                                      collection_type: str = 'dataset') -> DatasetCollectionBinding:
        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
            filter(DatasetCollectionBinding.id == collection_binding_id,
                   DatasetCollectionBinding.type == collection_type). \
            order_by(DatasetCollectionBinding.created_at). \
            first()

        return dataset_collection_binding

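With the new type column, bindings are keyed on (provider, model, type), so annotation indexes get their own vector collection instead of reusing a dataset one. A hedged usage sketch mirroring the annotation tasks later in this diff (provider/model names are examples):

    binding = DatasetCollectionBindingService.get_dataset_collection_binding(
        'openai', 'text-embedding-ada-002', 'annotation')
    # binding.collection_name has the form "Vector_index_<uuid>_Node";
    # binding.type == 'annotation', created on first use if missing.
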
@@ -17,8 +17,8 @@ from models.model import UploadFile, EndUser
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError

ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
                      'jpg', 'jpeg', 'png', 'webp', 'gif']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
                      'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
PREVIEW_WORDS_LIMIT = 3000


@@ -27,7 +27,13 @@ class FileService:
    @staticmethod
    def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
        extension = file.filename.split('.')[-1]
        if extension.lower() not in ALLOWED_EXTENSIONS:
        etl_type = current_app.config['ETL_TYPE']
        if etl_type == 'Unstructured':
            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
                                  'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
        else:
            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
        if extension.lower() not in allowed_extensions:
            raise UnsupportedFileTypeError()
        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
            raise UnsupportedFileTypeError()

@@ -154,3 +160,21 @@ class FileService:
        generator = storage.load(upload_file.key, stream=True)

        return generator, upload_file.mime_type

    @staticmethod
    def get_public_image_preview(file_id: str):
        upload_file = db.session.query(UploadFile) \
            .filter(UploadFile.id == file_id) \
            .first()

        if not upload_file:
            raise NotFound("File not found or signature is invalid")

        # only image files may be previewed publicly
        extension = upload_file.extension
        if extension.lower() not in IMAGE_EXTENSIONS:
            raise UnsupportedFileTypeError()

        generator = storage.load(upload_file.key)

        return generator, upload_file.mime_type

@@ -1,174 +0,0 @@
import datetime
import logging

import stripe
from flask import current_app

from core.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from models.account import Account
from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType


class ProviderCheckout:
    def __init__(self, stripe_checkout_session):
        self.stripe_checkout_session = stripe_checkout_session

    def get_checkout_url(self):
        return self.stripe_checkout_session.url


class ProviderCheckoutService:
    def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout:
        # check provider name is valid
        model_provider_rules = ModelProviderFactory.get_provider_rules()
        if provider_name not in model_provider_rules:
            raise ValueError(f'provider name {provider_name} is invalid')

        model_provider_rule = model_provider_rules[provider_name]

        # check provider name can be paid
        self._check_provider_payable(provider_name, model_provider_rule)

        # get stripe checkout product id
        paid_provider = self._get_paid_provider(tenant_id, provider_name)
        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider = model_provider_class(provider=paid_provider)
        payment_info = model_provider.get_payment_info()
        if not payment_info:
            raise ValueError(f'provider name {provider_name} not support payment')

        payment_product_id = payment_info['product_id']
        payment_min_quantity = payment_info['min_quantity']
        payment_max_quantity = payment_info['max_quantity']

        # create provider order
        provider_order = ProviderOrder(
            tenant_id=tenant_id,
            provider_name=provider_name,
            account_id=account.id,
            payment_product_id=payment_product_id,
            quantity=1,
            payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value
        )

        db.session.add(provider_order)
        db.session.flush()

        line_item = {
            'price': f'{payment_product_id}',
            'quantity': payment_min_quantity
        }

        if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity:
            line_item['adjustable_quantity'] = {
                'enabled': True,
                'minimum': payment_min_quantity,
                'maximum': payment_max_quantity
            }

        try:
            # create stripe checkout session
            checkout_session = stripe.checkout.Session.create(
                line_items=[
                    line_item
                ],
                mode='payment',
                success_url=current_app.config.get("CONSOLE_WEB_URL")
                + f'?provider_name={provider_name}&payment_result=succeeded',
                cancel_url=current_app.config.get("CONSOLE_WEB_URL")
                + f'?provider_name={provider_name}&payment_result=cancelled',
                automatic_tax={'enabled': True},
            )
        except Exception as e:
            logging.exception(e)
            raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later')

        provider_order.payment_id = checkout_session.id
        db.session.commit()

        return ProviderCheckout(checkout_session)

    def fulfill_provider_order(self, event, line_items):
        provider_order = db.session.query(ProviderOrder) \
            .filter(ProviderOrder.payment_id == event['data']['object']['id']) \
            .first()

        if not provider_order:
            raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')

        if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
            raise ValueError(
                f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')

        provider_order.transaction_id = event['data']['object']['payment_intent']
        provider_order.currency = event['data']['object']['currency']
        provider_order.total_amount = event['data']['object']['amount_subtotal']
        provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value
        provider_order.paid_at = datetime.datetime.utcnow()
        provider_order.updated_at = provider_order.paid_at

        # update provider quota
        provider = db.session.query(Provider).filter(
            Provider.tenant_id == provider_order.tenant_id,
            Provider.provider_name == provider_order.provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value,
            Provider.quota_type == ProviderQuotaType.PAID.value
        ).first()

        if not provider:
            raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, '
                             f'provider name: {provider_order.provider_name}')

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name)
        model_provider = model_provider_class(provider=provider)
        payment_info = model_provider.get_payment_info()

        quantity = line_items['data'][0]['quantity']

        if not payment_info:
            increase_quota = 0
        else:
            increase_quota = int(payment_info['increase_quota']) * quantity

        if increase_quota > 0:
            provider.quota_limit += increase_quota
            provider.is_valid = True

        db.session.commit()

    def _check_provider_payable(self, provider_name: str, model_provider_rule: dict):
        if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']:
            raise ValueError(f'provider name {provider_name} not support payment')

        if 'system_config' not in model_provider_rule:
            raise ValueError(f'provider name {provider_name} not support payment')

        if 'supported_quota_types' not in model_provider_rule['system_config']:
            raise ValueError(f'provider name {provider_name} not support payment')

        if 'paid' not in model_provider_rule['system_config']['supported_quota_types']:
            raise ValueError(f'provider name {provider_name} not support payment')

    def _get_paid_provider(self, tenant_id: str, provider_name: str):
        paid_provider = db.session.query(Provider) \
            .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name == provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value,
            Provider.quota_type == ProviderQuotaType.PAID.value,
        ).first()

        if not paid_provider:
            paid_provider = Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.SYSTEM.value,
                quota_type=ProviderQuotaType.PAID.value,
                quota_limit=0,
                quota_used=0,
            )
            db.session.add(paid_provider)
            db.session.commit()

        return paid_provider

@@ -1,8 +1,12 @@
from flask import current_app
from flask_login import current_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin
from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole
from models.provider import Provider

from services.billing_service import BillingService
from services.account_service import TenantService


class WorkspaceService:
    @classmethod

@@ -28,6 +32,13 @@ class WorkspaceService:
        ).first()
        tenant_info['role'] = tenant_account_join.role

        edition = current_app.config['EDITION']
        if edition == 'CLOUD':
            billing_info = BillingService.get_info(tenant_info['id'])

            if billing_info['can_replace_logo'] and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
                tenant_info['custom_config'] = tenant.custom_config_dict

        # Get providers
        providers = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id

61  api/tasks/annotation/add_annotation_to_index_task.py  Normal file
@@ -0,0 +1,61 @@
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document

from core.index.index import IndexBuilder

from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
                                 collection_binding_id: str):
    """
    Add annotation to index.
    :param annotation_id: annotation id
    :param question: question
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: embedding binding id

    Usage: add_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
"""
|
||||
logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id,
|
||||
'annotation'
|
||||
)
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
document = Document(
|
||||
page_content=question,
|
||||
metadata={
|
||||
"annotation_id": annotation_id,
|
||||
"app_id": app_id,
|
||||
"doc_id": annotation_id
|
||||
}
|
||||
)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.add_texts([document])
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Build index for annotation failed")
|
||||
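The task takes only plain ids and strings so it serializes cleanly onto the dataset queue; note that the annotation id doubles as the vector doc_id, which is what delete_annotation_index_task later filters on. A hedged enqueue sketch (variable names assumed from annotation_service.py above):

    add_annotation_to_index_task.delay(
        annotation.id,        # used as both annotation_id and doc_id metadata
        annotation.question,  # the text that gets embedded
        tenant_id,
        app_id,
        annotation_setting.collection_binding_id,
    )
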
99  api/tasks/annotation/batch_import_annotations_task.py  Normal file
@@ -0,0 +1,99 @@
import json
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str,
                                  user_id: str):
    """
    Batch import annotations and add them to the index.
    :param job_id: job_id
    :param content_list: content list
    :param tenant_id: tenant id
    :param app_id: app id
    :param user_id: user_id

    """
    logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green'))
    start_at = time.perf_counter()
    indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if app:
        try:
            documents = []
            for content in content_list:
                annotation = MessageAnnotation(
                    app_id=app.id,
                    content=content['answer'],
                    question=content['question'],
                    account_id=user_id
                )
                db.session.add(annotation)
                db.session.flush()

                document = Document(
                    page_content=content['question'],
                    metadata={
                        "annotation_id": annotation.id,
                        "app_id": app_id,
                        "doc_id": annotation.id
                    }
                )
                documents.append(document)
            # if annotation reply is enabled, add the annotations to the index in batch
            app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
                AppAnnotationSetting.app_id == app_id
            ).first()

            if app_annotation_setting:
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                    app_annotation_setting.collection_binding_id,
                    'annotation'
                )
                if not dataset_collection_binding:
                    raise NotFound("App annotation setting not found")
                dataset = Dataset(
                    id=app_id,
                    tenant_id=tenant_id,
                    indexing_technique='high_quality',
                    embedding_model_provider=dataset_collection_binding.provider_name,
                    embedding_model=dataset_collection_binding.model_name,
                    collection_binding_id=dataset_collection_binding.id
                )

                index = IndexBuilder.get_index(dataset, 'high_quality')
                if index:
                    index.add_texts(documents)

            db.session.commit()
            redis_client.setex(indexing_cache_key, 600, 'completed')
            end_at = time.perf_counter()
            logging.info(
                click.style(
                    'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at),
                    fg='green'))
        except Exception as e:
            db.session.rollback()
            redis_client.setex(indexing_cache_key, 600, 'error')
            indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
            redis_client.setex(indexing_error_msg_key, 600, str(e))
            logging.exception("Build index for batch import annotations failed")

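This task receives the rows parsed by AppAnnotationService.batch_import_app_annotations, which reads the CSV positionally: column 0 is the question, column 1 the answer, and the first row is consumed as the header. An example file (header text is arbitrary):

    question,answer
    What is Dify?,An LLM application development platform.
    How do I enable annotation reply?,Open the app settings and toggle annotation reply.
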
45  api/tasks/annotation/delete_annotation_index_task.py  Normal file
@@ -0,0 +1,45 @@
import datetime
import logging
import time

import click
from celery import shared_task
from core.index.index import IndexBuilder
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str,
                                 collection_binding_id: str):
    """
    Async delete annotation index task
    """
    logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id,
            'annotation'
        )

        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            collection_binding_id=dataset_collection_binding.id
        )

        vector_index = IndexBuilder.get_default_high_quality_index(dataset)
        if vector_index:
            try:
                vector_index.delete_by_metadata_field('annotation_id', annotation_id)
            except Exception:
                logging.exception("Delete annotation index failed when annotation deleted.")
        end_at = time.perf_counter()
        logging.info(
            click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Annotation deleted index failed:{}".format(str(e)))

74  api/tasks/annotation/disable_annotation_reply_task.py  Normal file
@@ -0,0 +1,74 @@
import datetime
import logging
import time

import click
from celery import shared_task
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting


@shared_task(queue='dataset')
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
    """
    Async disable annotation reply task
"""
|
||||
logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
|
||||
if not app_annotation_setting:
|
||||
raise NotFound("App annotation setting not found")
|
||||
|
||||
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
|
||||
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
|
||||
|
||||
try:
|
||||
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
collection_binding_id=app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
|
||||
if vector_index:
|
||||
try:
|
||||
vector_index.delete_by_metadata_field('app_id', app_id)
|
||||
except Exception:
|
||||
logging.exception("Delete doc index failed when dataset deleted.")
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, 'completed')
|
||||
|
||||
# delete annotation setting
|
||||
db.session.delete(app_annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
logging.exception("Annotation batch deleted index failed:{}".format(str(e)))
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, 'error')
|
||||
disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id))
|
||||
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
|
||||
finally:
|
||||
redis_client.delete(disable_app_annotation_key)
|
||||
106
api/tasks/annotation/enable_annotation_reply_task.py
Normal file
106
api/tasks/annotation/enable_annotation_reply_task.py
Normal file
@ -0,0 +1,106 @@
import datetime
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import MessageAnnotation, App, AppAnnotationSetting
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float,
                                 embedding_provider_name: str, embedding_model_name: str):
    """
    Async enable annotation reply task
    """
    logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green'))
    start_at = time.perf_counter()
    # get app info
    app = db.session.query(App).filter(
        App.id == app_id,
        App.tenant_id == tenant_id,
        App.status == 'normal'
    ).first()

    if not app:
        raise NotFound("App not found")

    annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
    enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
    enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))

    try:
        documents = []
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_provider_name,
            embedding_model_name,
            'annotation'
        )
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_id).first()
        if annotation_setting:
            annotation_setting.score_threshold = score_threshold
            annotation_setting.collection_binding_id = dataset_collection_binding.id
            annotation_setting.updated_user_id = user_id
            annotation_setting.updated_at = datetime.datetime.utcnow()
            db.session.add(annotation_setting)
        else:
            new_app_annotation_setting = AppAnnotationSetting(
                app_id=app_id,
                score_threshold=score_threshold,
                collection_binding_id=dataset_collection_binding.id,
                created_user_id=user_id,
                updated_user_id=user_id
            )
            db.session.add(new_app_annotation_setting)

        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=embedding_provider_name,
            embedding_model=embedding_model_name,
            collection_binding_id=dataset_collection_binding.id
        )
        if annotations:
            for annotation in annotations:
                document = Document(
                    page_content=annotation.question,
                    metadata={
                        "annotation_id": annotation.id,
                        "app_id": app_id,
                        "doc_id": annotation.id
                    }
                )
                documents.append(document)
            index = IndexBuilder.get_index(dataset, 'high_quality')
            if index:
                try:
                    index.delete_by_metadata_field('app_id', app_id)
                except Exception as e:
                    logging.info(
                        click.style('Delete annotation index error: {}'.format(str(e)),
                                    fg='red'))
                index.add_texts(documents)
        db.session.commit()
        redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
        end_at = time.perf_counter()
        logging.info(
            click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at),
                        fg='green'))
    except Exception as e:
        logging.exception("Annotation batch created index failed:{}".format(str(e)))
        redis_client.setex(enable_app_annotation_job_key, 600, 'error')
        enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id))
        redis_client.setex(enable_app_annotation_error_key, 600, str(e))
        db.session.rollback()
    finally:
        redis_client.delete(enable_app_annotation_key)
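A minimal dispatch sketch for the task above. The task signature, module path, and the Redis job-key format come from the file itself; the uuid job-id convention and all argument values are illustrative placeholders.

# Sketch only: enqueue enable_annotation_reply_task and read back its job status.
import uuid

from extensions.ext_redis import redis_client
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task

job_id = str(uuid.uuid4())
enable_annotation_reply_task.delay(
    job_id,
    'app-uuid',                # app_id (placeholder)
    'user-uuid',               # user_id
    'tenant-uuid',             # tenant_id
    0.8,                       # score_threshold
    'openai',                  # embedding_provider_name
    'text-embedding-ada-002',  # embedding_model_name
)
# The task writes 'completed' or 'error' to this key with a 600s TTL:
status = redis_client.get('enable_app_annotation_job_{}'.format(job_id))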
api/tasks/annotation/update_annotation_to_index_task.py (new file, +63 lines)
@@ -0,0 +1,63 @@
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document

from core.index.index import IndexBuilder

from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService


@shared_task(queue='dataset')
def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
                                    collection_binding_id: str):
    """
    Update annotation to index.
    :param annotation_id: annotation id
    :param question: question
    :param tenant_id: tenant id
    :param app_id: app id
    :param collection_binding_id: embedding binding id

    Usage: update_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green'))
    start_at = time.perf_counter()

    try:
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id,
            'annotation'
        )

        dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique='high_quality',
            embedding_model_provider=dataset_collection_binding.provider_name,
            embedding_model=dataset_collection_binding.model_name,
            collection_binding_id=dataset_collection_binding.id
        )

        document = Document(
            page_content=question,
            metadata={
                "annotation_id": annotation_id,
                "app_id": app_id,
                "doc_id": annotation_id
            }
        )
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.delete_by_metadata_field('annotation_id', annotation_id)
            index.add_texts([document])
        end_at = time.perf_counter()
        logging.info(
            click.style(
                'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
                fg='green'))
    except Exception:
        logging.exception("Build index for annotation failed")
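And a matching dispatch, mirroring the Usage note in the docstring above; the ID strings and question text are placeholders.

# Sketch: enqueue an index update for a single edited annotation.
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task

update_annotation_to_index_task.delay(
    'annotation-uuid',               # annotation_id (placeholder)
    'What are your opening hours?',  # question (the updated annotation text)
    'tenant-uuid',                   # tenant_id
    'app-uuid',                      # app_id
    'binding-uuid',                  # collection_binding_id
)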
@@ -2,7 +2,7 @@ version: '3.1'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.3.33
+    image: langgenius/dify-api:0.3.34
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -128,7 +128,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.3.33
+    image: langgenius/dify-api:0.3.34
     restart: always
     environment:
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -196,7 +196,7 @@ services:

   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.3.33
+    image: langgenius/dify-web:0.3.34
     restart: always
     environment:
       EDITION: SELF_HOSTED
@@ -0,0 +1,17 @@
import React from 'react'
import Main from '@/app/components/app/log-annotation'
import { PageType } from '@/app/components/app/configuration/toolbox/annotation/type'

export type IProps = {
  params: { appId: string }
}

const Logs = async ({
  params: { appId },
}: IProps) => {
  return (
    <Main pageType={PageType.annotation} appId={appId} />
  )
}

export default Logs
@@ -1,5 +1,6 @@
 import React from 'react'
-import Main from '@/app/components/app/log'
+import Main from '@/app/components/app/log-annotation'
+import { PageType } from '@/app/components/app/configuration/toolbox/annotation/type'

 export type IProps = {
   params: { appId: string }
@@ -9,7 +10,7 @@ const Logs = async ({
   params: { appId },
 }: IProps) => {
   return (
-    <Main appId={appId} />
+    <Main pageType={PageType.log} appId={appId} />
   )
 }

@@ -2,11 +2,11 @@ import classNames from 'classnames'
 import style from '../list.module.css'
 import Apps from './Apps'
 import { getLocaleOnServer } from '@/i18n/server'
-import { useTranslation } from '@/i18n/i18next-serverside-config'
+import { useTranslation as translate } from '@/i18n/i18next-serverside-config'

 const AppList = async () => {
   const locale = getLocaleOnServer()
-  const { t } = await useTranslation(locale, 'app')
+  const { t } = await translate(locale, 'app')

   return (
     <div className='flex flex-col overflow-auto bg-gray-100 shrink-0 grow'>
@@ -1,6 +1,6 @@
 import React from 'react'
 import { getLocaleOnServer } from '@/i18n/server'
-import { useTranslation } from '@/i18n/i18next-serverside-config'
+import { useTranslation as translate } from '@/i18n/i18next-serverside-config'
 import Form from '@/app/components/datasets/settings/form'

 type Props = {
@@ -11,8 +11,7 @@ const Settings = async ({
   params: { datasetId },
 }: Props) => {
   const locale = getLocaleOnServer()
-  // eslint-disable-next-line react-hooks/rules-of-hooks
-  const { t } = await useTranslation(locale, 'dataset-settings')
+  const { t } = await translate(locale, 'dataset-settings')

   return (
     <div className='bg-white h-full overflow-y-auto'>
@@ -28,7 +28,15 @@ export default function NavLink({
   mode = 'expand',
 }: NavLinkProps) {
   const segment = useSelectedLayoutSegment()
-  const isActive = href.toLowerCase().split('/')?.pop() === segment?.toLowerCase()
+  const formattedSegment = (() => {
+    let res = segment?.toLowerCase()
+    // logs and annotations use the same nav
+    if (res === 'annotations')
+      res = 'logs'
+
+    return res
+  })()
+  const isActive = href.toLowerCase().split('/')?.pop() === formattedSegment
   const NavIcon = isActive ? iconMap.selected : iconMap.normal

   return (
@@ -0,0 +1,47 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import Textarea from 'rc-textarea'
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'

export enum EditItemType {
  Query = 'query',
  Answer = 'answer',
}
type Props = {
  type: EditItemType
  content: string
  onChange: (content: string) => void
}

const EditItem: FC<Props> = ({
  type,
  content,
  onChange,
}) => {
  const { t } = useTranslation()
  const avatar = type === EditItemType.Query ? <User className='w-6 h-6' /> : <Robot className='w-6 h-6' />
  const name = type === EditItemType.Query ? t('appAnnotation.addModal.queryName') : t('appAnnotation.addModal.answerName')
  const placeholder = type === EditItemType.Query ? t('appAnnotation.addModal.queryPlaceholder') : t('appAnnotation.addModal.answerPlaceholder')

  return (
    <div className='flex' onClick={e => e.stopPropagation()}>
      <div className='shrink-0 mr-3'>
        {avatar}
      </div>
      <div className='grow'>
        <div className='mb-1 leading-[18px] text-xs font-semibold text-gray-900'>{name}</div>
        <Textarea
          className='mt-1 block w-full leading-5 max-h-none text-sm text-gray-700 outline-none appearance-none resize-none'
          value={content}
          onChange={(e: React.ChangeEvent<HTMLTextAreaElement>) => onChange(e.target.value)}
          autoSize={{ minRows: 3 }}
          placeholder={placeholder}
          autoFocus
        />
      </div>
    </div>
  )
}
export default React.memo(EditItem)