mirror of
https://github.com/langgenius/dify.git
synced 2026-01-26 23:05:45 +08:00
Compare commits
97 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 43741ad5d1 | |||
| 8dec406161 | |||
| 58f8d74591 | |||
| 867fc61b12 | |||
| 8e2e477a7f | |||
| 9b34f5a9ff | |||
| 5e34f938c1 | |||
| 2fd56cb01c | |||
| 4f0e272549 | |||
| 1a5279a3ef | |||
| 7775f5785f | |||
| 2de73991ff | |||
| 354d033e60 | |||
| ebc2cdad2e | |||
| 5bb841935e | |||
| 65fd4b39ce | |||
| 96d2de2258 | |||
| a71f2863ac | |||
| a9b942981d | |||
| 4b1ba2ec21 | |||
| c09184fd94 | |||
| b0d8d196e1 | |||
| 7c43123956 | |||
| eede84eb9e | |||
| b5b20234e9 | |||
| 5beb298e47 | |||
| 6b499b9a16 | |||
| 4c639961f5 | |||
| dfd3f507fb | |||
| d5695b3170 | |||
| 994fceece3 | |||
| 8c451eb0e6 | |||
| 79b4366203 | |||
| 3675d2eae8 | |||
| 38b55d2186 | |||
| bee0d12455 | |||
| 13f2c90a7b | |||
| a3dca3dabc | |||
| e5c7a81ce3 | |||
| 8b0100523b | |||
| 1350599c0b | |||
| bc54cdc537 | |||
| 5d10cf0fe6 | |||
| 7b8a10f3ea | |||
| cb3a55dae6 | |||
| 5789d76582 | |||
| 2e588ae221 | |||
| b5dd948e56 | |||
| 1263b7de75 | |||
| 75a6122173 | |||
| 053102f433 | |||
| d3a2c0ed34 | |||
| 8fbc374f31 | |||
| 08b7ebba91 | |||
| a1cd043fdc | |||
| 671a8e7972 | |||
| efa16dbb44 | |||
| a6241be42a | |||
| faa88aafe8 | |||
| 1b3a98425f | |||
| 22bc9ddc73 | |||
| 0423775687 | |||
| 307c170fb6 | |||
| 0e04fcc071 | |||
| 4322b17a81 | |||
| 451af66be0 | |||
| 454577c6b1 | |||
| 53be4d2712 | |||
| 3c37fd37fa | |||
| cf0ba794d7 | |||
| c21e2063fe | |||
| ad037c6615 | |||
| 7bbfac5dba | |||
| 80ddb00f10 | |||
| 74b2260ba6 | |||
| 603e55f252 | |||
| a9c1c7d239 | |||
| 3cc697832a | |||
| bb98f5756a | |||
| e1d2203371 | |||
| 93467cb363 | |||
| ea526d0822 | |||
| 0e627c920f | |||
| ea35f1dce1 | |||
| a5b80c9d1f | |||
| f704094a5f | |||
| 1f58f15bff | |||
| b930716745 | |||
| 9587479b76 | |||
| 3c0fbf3a6a | |||
| caa330c91f | |||
| 4a55d5729d | |||
| d6a6697891 | |||
| 778cfb37a2 | |||
| ce85ee3aa6 | |||
| b23de4affc | |||
| d8a7e894aa |
17
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
17
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -1,11 +1,18 @@
|
||||
name: "🕷️ Bug report"
|
||||
description: Report errors or unexpected behavior [please use English :)]
|
||||
description: Report errors or unexpected behavior
|
||||
labels:
|
||||
- bug
|
||||
body:
|
||||
- type: markdown
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
attributes:
|
||||
label: Dify version
|
||||
@ -21,8 +28,8 @@ body:
|
||||
multiple: true
|
||||
options:
|
||||
- Cloud
|
||||
- Self Hosted
|
||||
- Other (please specify in "Steps to Reproduce")
|
||||
- Self Hosted (Docker)
|
||||
- Self Hosted (Source)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
|
||||
10
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
10
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
@ -1,8 +1,16 @@
|
||||
name: "📚 Documentation Issue"
|
||||
description: Report issues in our documentation [please use English :)]
|
||||
description: Report issues in our documentation
|
||||
labels:
|
||||
- ducumentation
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Provide a description of requested docs changes
|
||||
|
||||
11
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
11
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
@ -1,8 +1,17 @@
|
||||
name: "⭐ Feature or enhancement request"
|
||||
description: Propose something new. [please use English :)]
|
||||
description: Propose something new.
|
||||
labels:
|
||||
- enhancement
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Description of the new feature / enhancement
|
||||
|
||||
9
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
9
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
@ -3,6 +3,15 @@ description: "Request help from the community" [please use English :)]
|
||||
labels:
|
||||
- help-wanted
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Provide a description of the help you need
|
||||
|
||||
10
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
10
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
@ -3,9 +3,15 @@ description: Report incorrect translations. [please use English :)]
|
||||
labels:
|
||||
- translation
|
||||
body:
|
||||
- type: markdown
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
label: Dify version
|
||||
|
||||
4
.github/workflows/build-api-image.yml
vendored
4
.github/workflows/build-api-image.yml
vendored
@ -34,9 +34,7 @@ jobs:
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v4
|
||||
|
||||
4
.github/workflows/build-web-image.yml
vendored
4
.github/workflows/build-web-image.yml
vendored
@ -34,9 +34,7 @@ jobs:
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v4
|
||||
|
||||
@ -20,6 +20,8 @@
|
||||
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
|
||||
</p>
|
||||
|
||||
[v0.3.31:Surpassing the Assistants API – Dify's RAG Demonstrates an Impressive 20% Improvement.](https://dify.ai/blog/dify-ai-rag-technology-upgrade-performance-improvement-qa-accuracy)
|
||||
|
||||
**Dify** is an LLM application development platform that has already seen over **100,000** applications built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the core tech stack required for building generative AI-native applications, including a built-in RAG engine. With Dify, **you can self-deploy capabilities similar to Assistants API and GPTs based on any LLMs.**
|
||||
|
||||

|
||||
|
||||
@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
|
||||
HOSTED_OPENAI_API_ORGANIZATION=
|
||||
HOSTED_OPENAI_QUOTA_LIMIT=200
|
||||
HOSTED_OPENAI_PAID_ENABLED=false
|
||||
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
|
||||
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
|
||||
|
||||
HOSTED_AZURE_OPENAI_ENABLED=false
|
||||
HOSTED_AZURE_OPENAI_API_KEY=
|
||||
@ -119,10 +117,6 @@ HOSTED_ANTHROPIC_API_BASE=
|
||||
HOSTED_ANTHROPIC_API_KEY=
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED=false
|
||||
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
|
||||
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
|
||||
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
|
||||
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
|
||||
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
ETL_TYPE=dify
|
||||
UNSTRUCTURED_API_URL=
|
||||
@ -53,12 +53,3 @@
|
||||
```
|
||||
7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
|
||||
8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
|
||||
|
||||
8. Start frontend
|
||||
|
||||
You can start the frontend by running `npm install && npm run dev` in web/ folder, or you can use docker to start the frontend, for example:
|
||||
|
||||
```
|
||||
docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5001 --name web-self-hosted langgenius/dify-web:latest
|
||||
```
|
||||
This will start a dify frontend, now you are all set, happy coding!
|
||||
@ -20,7 +20,7 @@ from flask_cors import CORS
|
||||
|
||||
from core.model_providers.providers import hosted
|
||||
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
|
||||
ext_database, ext_storage, ext_mail, ext_code_based_extension
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
@ -96,7 +96,6 @@ def initialize_extensions(app):
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
ext_stripe.init_app(app)
|
||||
|
||||
|
||||
# Flask-Login configuration
|
||||
|
||||
@ -28,7 +28,7 @@ from extensions.ext_database import db
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant, TenantAccountJoin
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
|
||||
from models.model import Account, AppModelConfig, App
|
||||
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
@click.command('add-annotation-question-field-value', help='add annotation question value')
|
||||
def add_annotation_question_field_value():
|
||||
click.echo(click.style('Start add annotation question value.', fg='green'))
|
||||
message_annotations = db.session.query(MessageAnnotation).all()
|
||||
message_annotation_deal_count = 0
|
||||
if message_annotations:
|
||||
for message_annotation in message_annotations:
|
||||
try:
|
||||
if message_annotation.message_id and not message_annotation.question:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_annotation.message_id
|
||||
).first()
|
||||
message_annotation.question = message.query
|
||||
db.session.add(message_annotation)
|
||||
db.session.commit()
|
||||
message_annotation_deal_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.echo(
|
||||
click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
@ -766,3 +790,4 @@ def register_commands(app):
|
||||
app.cli.add_command(normalization_collections)
|
||||
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
|
||||
app.cli.add_command(add_qdrant_full_text_index)
|
||||
app.cli.add_command(add_annotation_question_field_value)
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
import dotenv
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@ -15,6 +12,7 @@ DEFAULTS = {
|
||||
'DB_HOST': 'localhost',
|
||||
'DB_PORT': '5432',
|
||||
'DB_DATABASE': 'dify',
|
||||
'DB_CHARSET': '',
|
||||
'REDIS_HOST': 'localhost',
|
||||
'REDIS_PORT': '6379',
|
||||
'REDIS_DB': '0',
|
||||
@ -43,25 +41,21 @@ DEFAULTS = {
|
||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
|
||||
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
|
||||
'HOSTED_ANTHROPIC_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
|
||||
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
|
||||
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
|
||||
'HOSTED_MODERATION_ENABLED': 'False',
|
||||
'HOSTED_MODERATION_PROVIDERS': '',
|
||||
'TENANT_DOCUMENT_COUNT': 100,
|
||||
'CLEAN_DAY_SETTING': 30,
|
||||
'UPLOAD_FILE_SIZE_LIMIT': 15,
|
||||
'UPLOAD_FILE_BATCH_LIMIT': 5,
|
||||
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
|
||||
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
|
||||
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
|
||||
'INVITE_EXPIRY_HOURS': 72
|
||||
'INVITE_EXPIRY_HOURS': 72,
|
||||
'ETL_TYPE': 'dify',
|
||||
}
|
||||
|
||||
|
||||
@ -91,7 +85,7 @@ class Config:
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.3.31"
|
||||
self.CURRENT_VERSION = "0.3.34"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@ -149,10 +143,12 @@ class Config:
|
||||
# ------------------------
|
||||
db_credentials = {
|
||||
key: get_env(key) for key in
|
||||
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
|
||||
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE', 'DB_CHARSET']
|
||||
}
|
||||
|
||||
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
|
||||
db_extras = f"?client_encoding={db_credentials['DB_CHARSET']}" if db_credentials['DB_CHARSET'] else ""
|
||||
|
||||
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}"
|
||||
self.SQLALCHEMY_ENGINE_OPTIONS = {
|
||||
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
|
||||
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
|
||||
@ -240,7 +236,6 @@ class Config:
|
||||
self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
|
||||
|
||||
# Dataset Configurations.
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
|
||||
|
||||
# File upload Configurations.
|
||||
@ -267,8 +262,6 @@ class Config:
|
||||
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
||||
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
|
||||
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
||||
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
|
||||
|
||||
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
||||
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
||||
@ -280,14 +273,13 @@ class Config:
|
||||
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
|
||||
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
|
||||
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
|
||||
|
||||
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
|
||||
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
|
||||
|
||||
self.ETL_TYPE = get_env('ETL_TYPE')
|
||||
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
@ -301,6 +293,3 @@ class CloudEditionConfig(Config):
|
||||
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
|
||||
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
|
||||
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
|
||||
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
@ -9,7 +9,7 @@ api = ExternalApi(bp)
|
||||
from . import extension, setup, version, apikey, admin
|
||||
|
||||
# Import app controllers
|
||||
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
@ -26,5 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
|
||||
# Import universal chat controllers
|
||||
from .universal_chat import chat, conversation, message, parameter, audio
|
||||
|
||||
# Import webhook controllers
|
||||
from .webhook import stripe
|
||||
from .billing import billing
|
||||
|
||||
290
api/controllers/console/app/annotation.py
Normal file
290
api/controllers/console/app/annotation.py
Normal file
@ -0,0 +1,290 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, marshal
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import NoFileUploadedError
|
||||
from controllers.console.datasets.error import TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
|
||||
annotation_hit_history_fields
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from flask import request
|
||||
|
||||
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id, action):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
||||
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
|
||||
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
if action == 'enable':
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
elif action == 'disable':
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
else:
|
||||
raise ValueError('Unsupported annotation reply action')
|
||||
return result, 200
|
||||
|
||||
|
||||
class AppAnnotationSettingDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||
return result, 200
|
||||
|
||||
|
||||
class AppAnnotationSettingUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
return result, 200
|
||||
|
||||
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id, action):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ''
|
||||
if job_status == 'error':
|
||||
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
|
||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': job_status,
|
||||
'error_msg': error_msg
|
||||
}, 200
|
||||
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
keyword = request.args.get('keyword', default=None, type=str)
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
response = {
|
||||
'data': marshal(annotation_list, annotation_fields),
|
||||
'has_more': len(annotation_list) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
||||
class AnnotationExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response = {
|
||||
'data': marshal(annotation_list, annotation_fields)
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
||||
class AnnotationCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
return annotation
|
||||
|
||||
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
|
||||
|
||||
class AnnotationBatchImportStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ''
|
||||
if job_status == 'error':
|
||||
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
|
||||
error_msg = redis_client.get(indexing_error_msg_key).decode()
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': job_status,
|
||||
'error_msg': error_msg
|
||||
}, 200
|
||||
|
||||
|
||||
class AnnotationHitHistoryListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id, annotation_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
|
||||
page, limit)
|
||||
response = {
|
||||
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||
'has_more': len(annotation_hit_history_list) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
|
||||
api.add_resource(AnnotationReplyActionStatusApi,
|
||||
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
|
||||
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
|
||||
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
|
||||
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
|
||||
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
|
||||
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
|
||||
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
|
||||
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
|
||||
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
|
||||
@ -12,7 +12,7 @@ from constants.model_template import model_templates, demo_model_templates
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
@ -57,6 +57,7 @@ class AppListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields)
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -161,7 +161,7 @@ class ChatMessageApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
@ -6,22 +6,23 @@ from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, fields
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
|
||||
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.login import login_required
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from fields.conversation_fields import message_detail_fields, annotation_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# get app info
|
||||
app = _get_app(app_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
|
||||
parser.add_argument('content', type=str, location='json')
|
||||
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
|
||||
message_id = str(args['message_id'])
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app.id
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
annotation = message.annotation
|
||||
|
||||
if annotation:
|
||||
annotation.content = args['content']
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
content=args['content'],
|
||||
account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
return annotation
|
||||
|
||||
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@ -249,7 +230,7 @@ class MessageMoreLikeThisApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
|
||||
"""Modify app model config"""
|
||||
app_id = str(app_id)
|
||||
|
||||
app_model = _get_app(app_id)
|
||||
app = _get_app(app_id)
|
||||
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=request.json,
|
||||
mode=app_model.mode
|
||||
mode=app.mode
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
app_id=app.id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
db.session.add(new_app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app.app_model_config_id = new_app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(
|
||||
app_model,
|
||||
app,
|
||||
app_model_config=new_app_model_config
|
||||
)
|
||||
|
||||
|
||||
@ -62,16 +62,15 @@ class DailyConversationStatistic(Resource):
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'conversation_count': i.conversation_count
|
||||
})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'conversation_count': i.conversation_count
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
@ -124,16 +123,15 @@ class DailyTerminalsStatistic(Resource):
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'terminal_count': i.terminal_count
|
||||
})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'terminal_count': i.terminal_count
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
@ -187,18 +185,17 @@ class DailyTokenCostStatistic(Resource):
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'token_count': i.token_count,
|
||||
'total_price': i.total_price,
|
||||
'currency': 'USD'
|
||||
})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'token_count': i.token_count,
|
||||
'total_price': i.total_price,
|
||||
'currency': 'USD'
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
@ -256,16 +253,15 @@ LEFT JOIN conversations c on c.id=subquery.conversation_id
|
||||
GROUP BY date
|
||||
ORDER BY date"""
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'interactions': float(i.interactions.quantize(Decimal('0.01')))
|
||||
})
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'interactions': float(i.interactions.quantize(Decimal('0.01')))
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
@ -320,20 +316,19 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
|
||||
})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class AverageResponseTimeStatistic(Resource):
|
||||
@ -383,16 +378,15 @@ class AverageResponseTimeStatistic(Resource):
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'latency': round(i.latency * 1000, 4)
|
||||
})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'latency': round(i.latency * 1000, 4)
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
@ -447,16 +441,15 @@ WHERE app_id = :app_id'''
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'tps': round(i.tokens_per_second, 4)
|
||||
})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'tps': round(i.tokens_per_second, 4)
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
|
||||
61
api/controllers/console/billing/billing.py
Normal file
61
api/controllers/console/billing/billing.py
Normal file
@ -0,0 +1,61 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_login import current_user
|
||||
from flask import current_app
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class BillingInfo(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
|
||||
edition = current_app.config['EDITION']
|
||||
if edition != 'CLOUD':
|
||||
return {"enabled": False}
|
||||
|
||||
return BillingService.get_info(current_user.current_tenant_id)
|
||||
|
||||
|
||||
class Subscription(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
|
||||
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
||||
args = parser.parse_args()
|
||||
|
||||
BillingService.is_tenant_owner(current_user)
|
||||
|
||||
return BillingService.get_subscription(args['plan'],
|
||||
args['interval'],
|
||||
current_user.email,
|
||||
current_user.current_tenant_id)
|
||||
|
||||
|
||||
class Invoices(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
BillingService.is_tenant_owner(current_user)
|
||||
return BillingService.get_invoices(current_user.email)
|
||||
|
||||
|
||||
api.add_resource(BillingInfo, '/billing/info')
|
||||
api.add_resource(Subscription, '/billing/subscription')
|
||||
api.add_resource(Invoices, '/billing/invoices')
|
||||
@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
||||
from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \
|
||||
InvalidMetadataError, ArchivedDocumentImmutableError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
@ -194,6 +194,7 @@ class DatasetDocumentListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(documents_and_batch_fields)
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def post(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
@ -252,6 +253,7 @@ class DatasetInitApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
@ -693,6 +695,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
@ -770,14 +773,6 @@ class DocumentStatusApi(DocumentResource):
|
||||
if not document.archived:
|
||||
raise InvalidActionError('Document is not archived.')
|
||||
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + 1
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
|
||||
document.archived = False
|
||||
document.archived_at = None
|
||||
document.archived_by = None
|
||||
@ -856,21 +851,6 @@ class DocumentRecoverApi(DocumentResource):
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class DocumentLimitApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""get document limit"""
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
|
||||
return {
|
||||
'documents_count': documents_count,
|
||||
'documents_limit': tenant_document_count
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
|
||||
api.add_resource(DatasetDocumentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents')
|
||||
@ -896,4 +876,3 @@ api.add_resource(DocumentStatusApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
|
||||
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
|
||||
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
|
||||
api.add_resource(DocumentLimitApi, '/datasets/limit')
|
||||
|
||||
@ -11,7 +11,7 @@ from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from libs.login import login_required
|
||||
@ -114,6 +114,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def patch(self, dataset_id, segment_id, action):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -200,6 +201,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
@ -250,6 +252,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
@ -344,6 +347,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
@ -69,5 +69,20 @@ class FilePreviewApi(Resource):
|
||||
return {'content': text}
|
||||
|
||||
|
||||
class FileeSupportTypApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
etl_type = current_app.config['ETL_TYPE']
|
||||
if etl_type == 'Unstructured':
|
||||
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
|
||||
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
|
||||
else:
|
||||
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
|
||||
return {'allowed_extensions': allowed_extensions}
|
||||
|
||||
|
||||
api.add_resource(FileApi, '/files/upload')
|
||||
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
|
||||
api.add_resource(FileeSupportTypApi, '/files/support-type')
|
||||
|
||||
@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource):
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -73,7 +73,7 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -14,6 +14,7 @@ from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from controllers.console.wraps import cloud_edition_billing_resource_check
|
||||
|
||||
|
||||
class InstalledAppsListApi(Resource):
|
||||
@ -47,6 +48,7 @@ class InstalledAppsListApi(Resource):
|
||||
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
|
||||
|
||||
@ -105,7 +105,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
|
||||
@ -104,7 +104,7 @@ class UniversalChatStopApi(UniversalChatResource):
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -66,7 +66,7 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1,61 +0,0 @@
|
||||
import logging
|
||||
|
||||
import stripe
|
||||
from flask import request, current_app
|
||||
from flask_restful import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from services.provider_checkout_service import ProviderCheckoutService
|
||||
|
||||
|
||||
class StripeWebhookApi(Resource):
|
||||
@setup_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
payload = request.data
|
||||
sig_header = request.headers.get('STRIPE_SIGNATURE')
|
||||
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
return 'Invalid payload', 400
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
return 'Invalid signature', 400
|
||||
|
||||
# Handle the checkout.session.completed event
|
||||
if event['type'] == 'checkout.session.completed':
|
||||
logging.debug(event['data']['object']['id'])
|
||||
logging.debug(event['data']['object']['amount_subtotal'])
|
||||
logging.debug(event['data']['object']['currency'])
|
||||
logging.debug(event['data']['object']['payment_intent'])
|
||||
logging.debug(event['data']['object']['payment_status'])
|
||||
logging.debug(event['data']['object']['metadata'])
|
||||
|
||||
session = stripe.checkout.Session.retrieve(
|
||||
event['data']['object']['id'],
|
||||
expand=['line_items'],
|
||||
)
|
||||
|
||||
logging.debug(session.line_items['data'][0]['quantity'])
|
||||
|
||||
# Fulfill the purchase...
|
||||
provider_checkout_service = ProviderCheckoutService()
|
||||
|
||||
try:
|
||||
provider_checkout_service.fulfill_provider_order(event, session.line_items)
|
||||
except Exception as e:
|
||||
|
||||
logging.debug(str(e))
|
||||
return 'success', 200
|
||||
|
||||
return 'success', 200
|
||||
|
||||
|
||||
api.add_resource(StripeWebhookApi, '/webhook/stripe')
|
||||
@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse, marshal_with, abort, fields, marsh
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, TenantAccountJoin
|
||||
@ -47,6 +47,7 @@ class MemberInviteEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('members')
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('emails', type=str, required=True, location='json', action='append')
|
||||
|
||||
@ -9,8 +9,8 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from services.provider_checkout_service import ProviderCheckoutService
|
||||
from services.provider_service import ProviderService
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class ModelProviderListApi(Resource):
|
||||
@ -115,7 +115,7 @@ class ModelProviderModelValidateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -155,7 +155,7 @@ class ModelProviderModelUpdateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -184,7 +184,7 @@ class ModelProviderModelUpdateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
provider_service = ProviderCheckoutService()
|
||||
provider_checkout = provider_service.create_checkout(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
account=current_user
|
||||
)
|
||||
if provider_name != 'anthropic':
|
||||
raise ValueError(f'provider name {provider_name} is invalid')
|
||||
|
||||
return {
|
||||
'url': provider_checkout.get_checkout_url()
|
||||
}
|
||||
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account_id=current_user.id)
|
||||
return data
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||
|
||||
@ -10,12 +10,15 @@ from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, UnsupportedFileTypeError
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
import services
|
||||
from services.account_service import TenantService
|
||||
from services.workspace_service import WorkspaceService
|
||||
from services.file_service import FileService
|
||||
|
||||
provider_fields = {
|
||||
'provider_name': fields.String,
|
||||
@ -34,6 +37,7 @@ tenant_fields = {
|
||||
'providers': fields.List(fields.Nested(provider_fields)),
|
||||
'in_trial': fields.Boolean,
|
||||
'trial_end_reason': fields.String,
|
||||
'custom_config': fields.Raw(attribute='custom_config'),
|
||||
}
|
||||
|
||||
tenants_fields = {
|
||||
@ -130,6 +134,61 @@ class SwitchWorkspaceApi(Resource):
|
||||
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
|
||||
|
||||
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
|
||||
|
||||
|
||||
class CustomConfigWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('workspace_custom')
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('remove_webapp_brand', type=bool, location='json')
|
||||
parser.add_argument('replace_webapp_logo', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
custom_config_dict = {
|
||||
'remove_webapp_brand': args['remove_webapp_brand'],
|
||||
'replace_webapp_logo': args['replace_webapp_logo'],
|
||||
}
|
||||
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
|
||||
|
||||
tenant.custom_config_dict = custom_config_dict
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
||||
|
||||
|
||||
class WebappLogoWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('workspace_custom')
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
extension = file.filename.split('.')[-1]
|
||||
if extension.lower() not in ['svg', 'png']:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, current_user, True)
|
||||
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return { 'id': upload_file.id }, 201
|
||||
|
||||
|
||||
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
|
||||
@ -137,3 +196,5 @@ api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all ten
|
||||
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
|
||||
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
|
||||
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
|
||||
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
|
||||
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')
|
||||
|
||||
@ -5,6 +5,7 @@ from flask import current_app, abort
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
def account_initialization_required(view):
|
||||
@ -41,3 +42,35 @@ def only_edition_self_hosted(view):
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str,
|
||||
error_msg: str = "You have reached the limit of your subscription."):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
tenant_id = current_user.current_tenant_id
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
members = billing_info['members']
|
||||
apps = billing_info['apps']
|
||||
vector_space = billing_info['vector_space']
|
||||
annotation_quota_limit = billing_info['annotation_quota_limit']
|
||||
|
||||
if resource == 'members' and 0 < members['limit'] <= members['size']:
|
||||
abort(403, error_msg)
|
||||
elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
|
||||
abort(403, error_msg)
|
||||
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
|
||||
abort(403, error_msg)
|
||||
elif resource == 'workspace_custom' and not billing_info['can_replace_logo']:
|
||||
abort(403, error_msg)
|
||||
elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] < annotation_quota_limit['size']:
|
||||
abort(403, error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return decorated
|
||||
return interceptor
|
||||
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from flask import request, Response
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.files import api
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.file_service import FileService
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
class ImagePreviewApi(Resource):
|
||||
@ -29,9 +31,30 @@ class ImagePreviewApi(Resource):
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return Response(generator, mimetype=mimetype)
|
||||
|
||||
|
||||
class WorkspaceWebappLogoApi(Resource):
|
||||
def get(self, workspace_id):
|
||||
workspace_id = str(workspace_id)
|
||||
|
||||
custom_config = TenantService.get_custom_config(workspace_id)
|
||||
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
|
||||
|
||||
if not webapp_logo_file_id:
|
||||
raise NotFound(f'webapp logo is not found')
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_public_image_preview(
|
||||
webapp_logo_file_id,
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return Response(generator, mimetype=mimetype)
|
||||
|
||||
|
||||
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
|
||||
api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
|
||||
@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
|
||||
@ -98,7 +98,7 @@ class ChatApi(AppApiResource):
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
|
||||
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -150,7 +150,7 @@ class ChatStopApi(AppApiResource):
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -11,7 +11,7 @@ from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
|
||||
NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from libs.login import current_user
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
@ -24,6 +24,7 @@ from services.file_service import FileService
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
@ -88,6 +89,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
@ -147,6 +149,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
|
||||
class DocumentAddByFileApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
args = {}
|
||||
@ -212,6 +215,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
args = {}
|
||||
|
||||
@ -3,7 +3,7 @@ from flask_restful import reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
@ -14,6 +14,8 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
|
||||
|
||||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
@ -144,6 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
def post(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
@ -11,6 +11,7 @@ from libs.login import _get_user
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant, TenantAccountJoin, Account
|
||||
from models.model import ApiToken, App
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
def validate_app_token(view=None):
|
||||
@ -40,6 +41,33 @@ def validate_app_token(view=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str,
|
||||
api_token_type: str,
|
||||
error_msg: str = "You have reached the limit of your subscription."):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
billing_info = BillingService.get_info(api_token.tenant_id)
|
||||
|
||||
members = billing_info['members']
|
||||
apps = billing_info['apps']
|
||||
vector_space = billing_info['vector_space']
|
||||
|
||||
if resource == 'members' and 0 < members['limit'] <= members['size']:
|
||||
raise Unauthorized(error_msg)
|
||||
elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
|
||||
raise Unauthorized(error_msg)
|
||||
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
|
||||
raise Unauthorized(error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return decorated
|
||||
return interceptor
|
||||
|
||||
|
||||
def validate_dataset_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
|
||||
@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
|
||||
@ -68,7 +68,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -139,7 +139,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
from flask_restful import fields, marshal_with
|
||||
from flask import current_app
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from models.model import Site
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class AppSiteApi(WebApiResource):
|
||||
@ -39,6 +43,8 @@ class AppSiteApi(WebApiResource):
|
||||
'site': fields.Nested(site_fields),
|
||||
'model_config': fields.Nested(model_config_fields, allow_null=True),
|
||||
'plan': fields.String,
|
||||
'can_replace_logo': fields.Boolean,
|
||||
'custom_config': fields.Raw(attribute='custom_config'),
|
||||
}
|
||||
|
||||
@marshal_with(app_fields)
|
||||
@ -50,7 +56,14 @@ class AppSiteApi(WebApiResource):
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id)
|
||||
edition = os.environ.get('EDITION')
|
||||
can_replace_logo = False
|
||||
|
||||
if edition == 'CLOUD':
|
||||
info = BillingService.get_info(app_model.tenant_id)
|
||||
can_replace_logo = info['can_replace_logo']
|
||||
|
||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
||||
|
||||
|
||||
api.add_resource(AppSiteApi, '/site')
|
||||
@ -59,7 +72,7 @@ api.add_resource(AppSiteApi, '/site')
|
||||
class AppSiteInfo:
|
||||
"""Class to store site information."""
|
||||
|
||||
def __init__(self, tenant, app, site, end_user):
|
||||
def __init__(self, tenant, app, site, end_user, can_replace_logo):
|
||||
"""Initialize AppSiteInfo instance."""
|
||||
self.app_id = app.id
|
||||
self.end_user_id = end_user
|
||||
@ -67,6 +80,16 @@ class AppSiteInfo:
|
||||
self.site = site
|
||||
self.model_config = None
|
||||
self.plan = tenant.plan
|
||||
self.can_replace_logo = can_replace_logo
|
||||
|
||||
if can_replace_logo:
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
|
||||
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
|
||||
self.custom_config = {
|
||||
'remove_webapp_brand': remove_webapp_brand,
|
||||
'replace_webapp_logo': replace_webapp_logo,
|
||||
}
|
||||
|
||||
if app.enable_site and site.prompt_public:
|
||||
app_model_config = app.app_model_config
|
||||
|
||||
@ -40,7 +40,7 @@ def decode_jwt_token():
|
||||
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code and not site:
|
||||
if not app_code or not site:
|
||||
raise Unauthorized('Site URL is no longer valid.')
|
||||
if app_model.enable_site is False:
|
||||
raise Unauthorized('Site is disabled.')
|
||||
|
||||
@ -59,7 +59,7 @@ class AgentExecutor:
|
||||
self.configuration = configuration
|
||||
self.agent = self._init_agent()
|
||||
|
||||
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
|
||||
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
|
||||
@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.file.file_obj import FileObj
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Account, Conversation, EndUser
|
||||
from core.moderation.base import ModerationException, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
class Completion:
|
||||
@ -33,7 +38,7 @@ class Completion:
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
|
||||
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
|
||||
auto_generate_name: bool = True):
|
||||
auto_generate_name: bool = True, from_source: str = 'console'):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
@ -109,7 +114,10 @@ class Completion:
|
||||
fake_response=str(e)
|
||||
)
|
||||
return
|
||||
|
||||
# check annotation reply
|
||||
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
|
||||
if annotation_reply:
|
||||
return
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_model_config.external_data_tools_list
|
||||
if external_data_tools:
|
||||
@ -166,17 +174,18 @@ class Completion:
|
||||
except ChunkedEncodingError as e:
|
||||
# Interrupt by LLM (like OpenAI), handle it.
|
||||
logging.warning(f'ChunkedEncodingError: {e}')
|
||||
conversation_message_task.end()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
|
||||
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
|
||||
query: str):
|
||||
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
|
||||
return inputs, query
|
||||
|
||||
type = app_model_config.sensitive_word_avoidance_dict['type']
|
||||
|
||||
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
|
||||
moderation = ModerationFactory(type, app_id, tenant_id,
|
||||
app_model_config.sensitive_word_avoidance_dict['config'])
|
||||
moderation_result = moderation.moderation_for_inputs(inputs, query)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
@ -324,6 +333,81 @@ class Completion:
|
||||
external_context = memory.load_memory_variables({})
|
||||
return external_context[memory_key]
|
||||
|
||||
@classmethod
|
||||
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
|
||||
from_source: str) -> bool:
|
||||
"""Get memory messages."""
|
||||
app_model_config = conversation_message_task.app_model_config
|
||||
app = conversation_message_task.app
|
||||
annotation_reply = app_model_config.annotation_reply_dict
|
||||
if annotation_reply['enabled']:
|
||||
try:
|
||||
score_threshold = annotation_reply.get('score_threshold', 1)
|
||||
embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
|
||||
embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
|
||||
# get embedding model
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=app.tenant_id,
|
||||
model_provider_name=embedding_provider_name,
|
||||
model_name=embedding_model_name
|
||||
)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name,
|
||||
embedding_model_name,
|
||||
'annotation'
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings,
|
||||
attributes=['doc_id', 'annotation_id', 'app_id']
|
||||
)
|
||||
|
||||
documents = vector_index.search(
|
||||
conversation_message_task.query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': 1,
|
||||
'score_threshold': score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
if documents:
|
||||
annotation_id = documents[0].metadata['annotation_id']
|
||||
score = documents[0].metadata['score']
|
||||
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
||||
if annotation:
|
||||
conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
|
||||
# insert annotation history
|
||||
AppAnnotationService.add_annotation_history(annotation.id,
|
||||
app.id,
|
||||
annotation.question,
|
||||
annotation.content,
|
||||
conversation_message_task.query,
|
||||
conversation_message_task.user.id,
|
||||
conversation_message_task.message.id,
|
||||
from_source,
|
||||
score)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning(f'Query annotation failed, exception: {str(e)}.')
|
||||
return False
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
|
||||
conversation: Conversation,
|
||||
|
||||
@ -319,9 +319,13 @@ class ConversationMessageTask:
|
||||
self._pub_handler.pub_message_end(self.retriever_resource)
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
|
||||
self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
|
||||
class PubHandler:
|
||||
def __init__(self, user: Union[Account | EndUser], task_id: str,
|
||||
def __init__(self, user: Union[Account, EndUser], task_id: str,
|
||||
message: Message, conversation: Conversation,
|
||||
chain_pub: bool = False, agent_thought_pub: bool = False):
|
||||
self._channel = PubHandler.generate_channel_name(user, task_id)
|
||||
@ -334,7 +338,7 @@ class PubHandler:
|
||||
self._agent_thought_pub = agent_thought_pub
|
||||
|
||||
@classmethod
|
||||
def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str):
|
||||
def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
|
||||
if not user:
|
||||
raise ValueError("user is required")
|
||||
|
||||
@ -342,7 +346,7 @@ class PubHandler:
|
||||
return "generate_result:{}-{}".format(user_str, task_id)
|
||||
|
||||
@classmethod
|
||||
def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
|
||||
def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
|
||||
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
|
||||
return "generate_result_stopped:{}-{}".format(user_str, task_id)
|
||||
|
||||
@ -435,7 +439,7 @@ class PubHandler:
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
'conversation_id': self._conversation.id,
|
||||
}
|
||||
}
|
||||
if retriever_resource:
|
||||
@ -446,6 +450,30 @@ class PubHandler:
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
|
||||
content = {
|
||||
'event': 'annotation',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id,
|
||||
'text': text,
|
||||
'annotation_id': annotation_id,
|
||||
'annotation_author_name': annotation_author_name
|
||||
}
|
||||
}
|
||||
self._message.answer = text
|
||||
self._message.provider_response_latency = time.perf_counter() - start_at
|
||||
|
||||
db.session.commit()
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_end(self):
|
||||
content = {
|
||||
'event': 'end',
|
||||
@ -454,7 +482,7 @@ class PubHandler:
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
@classmethod
|
||||
def pub_error(cls, user: Union[Account | EndUser], task_id: str, e):
|
||||
def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
|
||||
content = {
|
||||
'error': type(e).__name__,
|
||||
'description': e.description if getattr(e, 'description', None) is not None else str(e)
|
||||
@ -467,7 +495,7 @@ class PubHandler:
|
||||
return redis_client.get(self._stopped_cache_key) is not None
|
||||
|
||||
@classmethod
|
||||
def ping(cls, user: Union[Account | EndUser], task_id: str):
|
||||
def ping(cls, user: Union[Account, EndUser], task_id: str):
|
||||
content = {
|
||||
'event': 'ping'
|
||||
}
|
||||
@ -476,7 +504,7 @@ class PubHandler:
|
||||
redis_client.publish(channel, json.dumps(content))
|
||||
|
||||
@classmethod
|
||||
def stop(cls, user: Union[Account | EndUser], task_id: str):
|
||||
def stop(cls, user: Union[Account, EndUser], task_id: str):
|
||||
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
|
||||
@ -3,7 +3,8 @@ from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import requests
|
||||
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
|
||||
from flask import current_app
|
||||
from langchain.document_loaders import TextLoader, Docx2txtLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.data_loader.loader.csv_loader import CSVLoader
|
||||
@ -11,6 +12,13 @@ from core.data_loader.loader.excel import ExcelLoader
|
||||
from core.data_loader.loader.html import HTMLLoader
|
||||
from core.data_loader.loader.markdown import MarkdownLoader
|
||||
from core.data_loader.loader.pdf import PdfLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
@ -49,14 +57,34 @@ class FileExtractor:
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
file_extension = input_file.suffix.lower()
|
||||
if is_automatic:
|
||||
loader = UnstructuredFileLoader(
|
||||
file_path, strategy="hi_res", mode="elements"
|
||||
)
|
||||
# loader = UnstructuredAPIFileLoader(
|
||||
# file_path=filenames[0],
|
||||
# api_key="FAKE_API_KEY",
|
||||
# )
|
||||
etl_type = current_app.config['ETL_TYPE']
|
||||
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
|
||||
if etl_type == 'Unstructured':
|
||||
if file_extension == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif file_extension == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file_extension == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
elif file_extension == '.msg':
|
||||
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.eml':
|
||||
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.ppt':
|
||||
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.pptx':
|
||||
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.xml':
|
||||
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
|
||||
else:
|
||||
# txt
|
||||
loader = UnstructuredTextLoader(file_path, unstructured_api_url)
|
||||
else:
|
||||
if file_extension == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
|
||||
41
api/core/data_loader/loader/unstructured/unstructured_eml.py
Normal file
41
api/core/data_loader/loader/unstructured/unstructured_eml.py
Normal file
@ -0,0 +1,41 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Tuple, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredEmailLoader(BaseLoader):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
elements = partition_email(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@ -0,0 +1,48 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMarkdownLoader(BaseLoader):
|
||||
"""Load md files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
|
||||
remove_hyperlinks: Whether to remove hyperlinks from the text.
|
||||
|
||||
remove_images: Whether to remove images from the text.
|
||||
|
||||
encoding: File encoding to use. If `None`, the file will be loaded
|
||||
with the default system encoding.
|
||||
|
||||
autodetect_encoding: Whether to try to autodetect the file encoding
|
||||
if the specified encoding fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
from unstructured.partition.md import partition_md
|
||||
|
||||
elements = partition_md(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
40
api/core/data_loader/loader/unstructured/unstructured_msg.py
Normal file
40
api/core/data_loader/loader/unstructured/unstructured_msg.py
Normal file
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Tuple, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMsgLoader(BaseLoader):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
40
api/core/data_loader/loader/unstructured/unstructured_ppt.py
Normal file
40
api/core/data_loader/loader/unstructured/unstructured_ppt.py
Normal file
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Tuple, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTLoader(BaseLoader):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Tuple, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTXLoader(BaseLoader):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Tuple, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredTextLoader(BaseLoader):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
from unstructured.partition.text import partition_text
|
||||
|
||||
elements = partition_text(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
40
api/core/data_loader/loader/unstructured/unstructured_xml.py
Normal file
40
api/core/data_loader/loader/unstructured/unstructured_xml.py
Normal file
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Tuple, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredXmlLoader(BaseLoader):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@ -8,7 +8,7 @@ from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class DatesetDocumentStore:
|
||||
class DatasetDocumentStore:
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
@ -20,7 +20,7 @@ class DatesetDocumentStore:
|
||||
self._document_id = document_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
@ -18,31 +18,30 @@ class CacheEmbedding(Embeddings):
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
# use doc embedding cache or store if not exists
|
||||
text_embeddings = []
|
||||
embedding_queue_texts = []
|
||||
for text in texts:
|
||||
text_embeddings = [None for _ in range(len(texts))]
|
||||
embedding_queue_indices = []
|
||||
for i, text in enumerate(texts):
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
|
||||
if embedding:
|
||||
text_embeddings.append(embedding.get_embedding())
|
||||
text_embeddings[i] = embedding.get_embedding()
|
||||
else:
|
||||
embedding_queue_texts.append(text)
|
||||
embedding_queue_indices.append(i)
|
||||
|
||||
if embedding_queue_texts:
|
||||
if embedding_queue_indices:
|
||||
try:
|
||||
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
|
||||
embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
|
||||
except Exception as ex:
|
||||
raise self._embeddings.handle_exceptions(ex)
|
||||
i = 0
|
||||
normalized_embedding_results = []
|
||||
for text in embedding_queue_texts:
|
||||
hash = helper.generate_text_hash(text)
|
||||
|
||||
for i, indice in enumerate(embedding_queue_indices):
|
||||
hash = helper.generate_text_hash(texts[indice])
|
||||
|
||||
try:
|
||||
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
|
||||
vector = embedding_results[i]
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
||||
normalized_embedding_results.append(normalized_embedding)
|
||||
text_embeddings[indice] = normalized_embedding
|
||||
embedding.set_embedding(normalized_embedding)
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
@ -52,10 +51,7 @@ class CacheEmbedding(Embeddings):
|
||||
except:
|
||||
logging.exception('Failed to add embedding to db')
|
||||
continue
|
||||
finally:
|
||||
i += 1
|
||||
|
||||
text_embeddings.extend(normalized_embedding_results)
|
||||
return text_embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
|
||||
@ -10,7 +10,7 @@ from flask import current_app
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
|
||||
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
|
||||
@ -32,6 +32,10 @@ class BaseIndex(ABC):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
pass
|
||||
|
||||
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||
return KeywordTableRetriever(index=self, **kwargs)
|
||||
|
||||
|
||||
@ -100,7 +100,6 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
|
||||
return MilvusVectorStore(
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
@ -121,6 +120,16 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
vector_store.del_texts({
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
|
||||
@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
|
||||
@ -9,12 +9,17 @@ from models.dataset import Dataset, Document
|
||||
|
||||
|
||||
class VectorIndex:
|
||||
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
|
||||
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
|
||||
attributes: list = None):
|
||||
if attributes is None:
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
||||
self._dataset = dataset
|
||||
self._embeddings = embeddings
|
||||
self._vector_index = self._init_vector_index(dataset, config, embeddings)
|
||||
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
|
||||
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
|
||||
attributes: list) -> BaseVectorIndex:
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
@ -33,7 +38,8 @@ class VectorIndex:
|
||||
api_key=config.get('WEAVIATE_API_KEY'),
|
||||
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
||||
),
|
||||
embeddings=embeddings
|
||||
embeddings=embeddings,
|
||||
attributes=attributes
|
||||
)
|
||||
elif vector_type == "qdrant":
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
@ -27,9 +27,10 @@ class WeaviateConfig(BaseModel):
|
||||
|
||||
class WeaviateVectorIndex(BaseVectorIndex):
|
||||
|
||||
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
|
||||
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
|
||||
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
|
||||
@ -111,7 +112,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
attributes = self._attributes
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
|
||||
@ -141,6 +142,27 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
"valueText": document_id
|
||||
})
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.del_texts({
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
})
|
||||
|
||||
def delete_by_group_id(self, group_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
|
||||
@ -15,7 +15,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.docstore.dataset_docstore import DatesetDocumentStore
|
||||
from core.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
@ -106,7 +106,8 @@ class IndexingRunner:
|
||||
document_id=dataset_document.id
|
||||
).all()
|
||||
|
||||
db.session.delete(document_segments)
|
||||
for document_segment in document_segments:
|
||||
db.session.delete(document_segment)
|
||||
db.session.commit()
|
||||
|
||||
# load file
|
||||
@ -396,7 +397,7 @@ class IndexingRunner:
|
||||
one_or_none()
|
||||
|
||||
if file_detail:
|
||||
text_docs = FileExtractor.load(file_detail, is_automatic=False)
|
||||
text_docs = FileExtractor.load(file_detail, is_automatic=True)
|
||||
elif dataset_document.data_source_type == 'notion_import':
|
||||
loader = NotionLoader.from_document(dataset_document)
|
||||
text_docs = loader.load()
|
||||
@ -474,7 +475,7 @@ class IndexingRunner:
|
||||
)
|
||||
|
||||
# save node to document segment
|
||||
doc_store = DatesetDocumentStore(
|
||||
doc_store = DatasetDocumentStore(
|
||||
dataset=dataset,
|
||||
user_id=dataset_document.created_by,
|
||||
document_id=dataset_document.id
|
||||
@ -631,8 +632,8 @@ class IndexingRunner:
|
||||
return text
|
||||
|
||||
def format_split_text(self, text):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
|
||||
matches = re.findall(regex, text, re.MULTILINE)
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
|
||||
matches = re.findall(regex, text, re.UNICODE)
|
||||
|
||||
return [
|
||||
{
|
||||
|
||||
@ -75,6 +75,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'cohere':
|
||||
from core.model_providers.providers.cohere_provider import CohereProvider
|
||||
return CohereProvider
|
||||
elif provider_name == 'jina':
|
||||
from core.model_providers.providers.jina_provider import JinaProvider
|
||||
return JinaProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
25
api/core/model_providers/models/embedding/jina_embedding.py
Normal file
25
api/core/model_providers/models/embedding/jina_embedding.py
Normal file
@ -0,0 +1,25 @@
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
|
||||
|
||||
|
||||
class JinaEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = JinaEmbeddings(
|
||||
model=name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"Jina: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
@ -23,7 +23,8 @@ FUNCTION_CALL_MODELS = [
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-35-turbo',
|
||||
'gpt-35-turbo-16k'
|
||||
'gpt-35-turbo-16k',
|
||||
'gpt-4-1106-preview'
|
||||
]
|
||||
|
||||
class AzureOpenAIModel(BaseLLM):
|
||||
|
||||
@ -1,27 +1,45 @@
|
||||
import decimal
|
||||
import logging
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import ChatGLM
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import LLMResult, get_buffer_string
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
|
||||
LLMAPIUnavailableError, LLMAPIConnectionError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
|
||||
|
||||
class ChatGLMModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatGLM(
|
||||
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p')
|
||||
}
|
||||
|
||||
if provider_model_kwargs.get('max_length') is not None:
|
||||
extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
|
||||
|
||||
client = EnhanceChatOpenAI(
|
||||
model_name=self.name,
|
||||
temperature=provider_model_kwargs.get('temperature'),
|
||||
max_tokens=provider_model_kwargs.get('max_tokens'),
|
||||
model_kwargs=extra_model_kwargs,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
endpoint_url=self.credentials.get('api_base'),
|
||||
**provider_model_kwargs
|
||||
request_timeout=60,
|
||||
openai_api_key="1",
|
||||
openai_api_base=self.credentials['api_base'] + '/v1'
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p')
|
||||
}
|
||||
|
||||
self.client.temperature = provider_model_kwargs.get('temperature')
|
||||
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
|
||||
self.client.model_kwargs = extra_model_kwargs
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to ChatGLM API.")
|
||||
return LLMBadRequestError(str(ex))
|
||||
elif isinstance(ex, openai.error.APIConnectionError):
|
||||
logging.warning("Failed to connect to ChatGLM API.")
|
||||
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||
logging.warning("ChatGLM service unavailable.")
|
||||
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
||||
@ -1,14 +1,15 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
import cohere
|
||||
import openai
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.error import (LLMAPIConnectionError,
|
||||
LLMAPIUnavailableError,
|
||||
LLMAuthorizationError,
|
||||
LLMBadRequestError, LLMRateLimitError)
|
||||
from core.model_providers.models.reranking.base import BaseReranking
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
class CohereReranking(BaseReranking):
|
||||
@ -23,13 +24,20 @@ class CohereReranking(BaseReranking):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> \
|
||||
Optional[List[Document]]:
|
||||
if not documents:
|
||||
return []
|
||||
docs = []
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
documents = unique_documents
|
||||
|
||||
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
|
||||
rerank_documents = []
|
||||
|
||||
|
||||
@ -0,0 +1,62 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.reranking.base import BaseReranking
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from langchain.schema import Document
|
||||
from xinference_client.client.restful.restful_client import Client
|
||||
|
||||
|
||||
class XinferenceReranking(BaseReranking):
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
self.credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = Client(self.credentials['server_url'])
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||
if not documents:
|
||||
return []
|
||||
docs = []
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
documents = unique_documents
|
||||
|
||||
model = self.client.get_model(self.credentials['model_uid'])
|
||||
response = model.rerank(query=query, documents=docs, top_n=top_k)
|
||||
rerank_documents = []
|
||||
|
||||
for idx, result in enumerate(response['results']):
|
||||
# format document
|
||||
index = result['index']
|
||||
rerank_document = Document(
|
||||
page_content=result['document'],
|
||||
metadata={
|
||||
"doc_id": documents[index].metadata['doc_id'],
|
||||
"doc_hash": documents[index].metadata['doc_hash'],
|
||||
"document_id": documents[index].metadata['document_id'],
|
||||
"dataset_id": documents[index].metadata['dataset_id'],
|
||||
'score': result['relevance_score']
|
||||
}
|
||||
)
|
||||
# score threshold check
|
||||
if score_threshold is not None:
|
||||
if result['relevance_score'] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
return rerank_documents
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Xinference rerank: {str(ex)}")
|
||||
@ -32,9 +32,12 @@ class AnthropicProvider(BaseModelProvider):
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
'id': 'claude-2.1',
|
||||
'name': 'claude-2.1',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'claude-2',
|
||||
@ -44,6 +47,11 @@ class AnthropicProvider(BaseModelProvider):
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
]
|
||||
else:
|
||||
return []
|
||||
@ -73,12 +81,18 @@ class AnthropicProvider(BaseModelProvider):
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
model_max_tokens = {
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'claude-2.1': 200000,
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
|
||||
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=model_max_tokens.get(model_name, 100000), default=256, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -177,23 +191,6 @@ class AnthropicProvider(BaseModelProvider):
|
||||
|
||||
return False
|
||||
|
||||
def get_payment_info(self) -> Optional[dict]:
|
||||
"""
|
||||
get product info if it payable.
|
||||
|
||||
:return:
|
||||
"""
|
||||
if hosted_model_providers.anthropic \
|
||||
and hosted_model_providers.anthropic.paid_enabled:
|
||||
return {
|
||||
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
|
||||
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
|
||||
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
|
||||
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
|
||||
@ -122,6 +122,22 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-4-1106-preview',
|
||||
'name': 'gpt-4-1106-preview',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-4-vision-preview',
|
||||
'name': 'gpt-4-vision-preview',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.VISION.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'text-davinci-003',
|
||||
'name': 'text-davinci-003',
|
||||
@ -171,6 +187,8 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
base_model_max_tokens = {
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-4-1106-preview': 4096,
|
||||
'gpt-4-vision-preview': 4096,
|
||||
'gpt-35-turbo': 4096,
|
||||
'gpt-35-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
@ -376,6 +394,18 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
provider_credentials=credentials
|
||||
)
|
||||
|
||||
self._add_provider_model(
|
||||
model_name='gpt-4-1106-preview',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
provider_credentials=credentials
|
||||
)
|
||||
|
||||
self._add_provider_model(
|
||||
model_name='gpt-4-vision-preview',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
provider_credentials=credentials
|
||||
)
|
||||
|
||||
self._add_provider_model(
|
||||
model_name='text-davinci-003',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
|
||||
@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC):
|
||||
).update({'last_used': datetime.utcnow()})
|
||||
db.session.commit()
|
||||
|
||||
def get_payment_info(self) -> Optional[dict]:
|
||||
"""
|
||||
get product info if it payable.
|
||||
|
||||
:return:
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
|
||||
"""
|
||||
get provider model.
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from langchain.llms import ChatGLM
|
||||
|
||||
from core.helper import encrypter
|
||||
@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'chatglm2-6b',
|
||||
'name': 'ChatGLM2-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
'id': 'chatglm3-6b',
|
||||
'name': 'ChatGLM3-6B',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm-6b',
|
||||
'name': 'ChatGLM-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
'id': 'chatglm3-6b-32k',
|
||||
'name': 'ChatGLM3-6B-32K',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm2-6b',
|
||||
'name': 'ChatGLM2-6B',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
model_max_tokens = {
|
||||
'chatglm-6b': 2000,
|
||||
'chatglm2-6b': 32000,
|
||||
'chatglm3-6b-32k': 32000,
|
||||
'chatglm3-6b': 8000,
|
||||
'chatglm2-6b': 8000,
|
||||
}
|
||||
|
||||
max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
|
||||
max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'endpoint_url': credentials['api_base']
|
||||
}
|
||||
response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
|
||||
|
||||
llm = ChatGLM(
|
||||
max_token=10,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
if response.status_code != 200:
|
||||
raise Exception('ChatGLM Endpoint URL is invalid.')
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
|
||||
@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel):
|
||||
quota_limit: int = 0
|
||||
"""Quota limit for the openai hosted model. -1 means unlimited."""
|
||||
paid_enabled: bool = False
|
||||
paid_stripe_price_id: str = None
|
||||
paid_increase_quota: int = 1
|
||||
|
||||
|
||||
class HostedAzureOpenAI(BaseModel):
|
||||
@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel):
|
||||
quota_limit: int = 0
|
||||
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
|
||||
paid_enabled: bool = False
|
||||
paid_stripe_price_id: str = None
|
||||
paid_increase_quota: int = 1000000
|
||||
paid_min_quantity: int = 20
|
||||
paid_max_quantity: int = 100
|
||||
|
||||
|
||||
class HostedModelProviders(BaseModel):
|
||||
@ -68,8 +62,6 @@ def init_app(app: Flask):
|
||||
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
|
||||
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
|
||||
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
|
||||
paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
|
||||
paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
|
||||
)
|
||||
|
||||
if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
|
||||
@ -85,10 +77,6 @@ def init_app(app: Flask):
|
||||
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
|
||||
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
|
||||
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
|
||||
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
|
||||
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
|
||||
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
|
||||
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
|
||||
)
|
||||
|
||||
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
|
||||
|
||||
141
api/core/model_providers/providers/jina_provider.py
Normal file
141
api/core/model_providers/providers/jina_provider.py
Normal file
@ -0,0 +1,141 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class JinaProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'jina'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.EMBEDDINGS:
|
||||
return [
|
||||
{
|
||||
'id': 'jina-embeddings-v2-base-en',
|
||||
'name': 'jina-embeddings-v2-base-en',
|
||||
},
|
||||
{
|
||||
'id': 'jina-embeddings-v2-small-en',
|
||||
'name': 'jina-embeddings-v2-small-en',
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.EMBEDDINGS:
|
||||
model_class = JinaEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Jina API Key must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'api_key': credentials['api_key'],
|
||||
}
|
||||
|
||||
embedding = JinaEmbeddings(
|
||||
model='jina-embeddings-v2-small-en',
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value:
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_key': None,
|
||||
}
|
||||
|
||||
if credentials['api_key']:
|
||||
credentials['api_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||
|
||||
return credentials
|
||||
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.get_provider_credentials(obfuscated)
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
raise NotImplementedError
|
||||
@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider):
|
||||
|
||||
return False
|
||||
|
||||
def get_payment_info(self) -> Optional[dict]:
|
||||
"""
|
||||
get payment info if it payable.
|
||||
|
||||
:return:
|
||||
"""
|
||||
if hosted_model_providers.openai \
|
||||
and hosted_model_providers.openai.paid_enabled:
|
||||
return {
|
||||
'product_id': hosted_model_providers.openai.paid_stripe_price_id,
|
||||
'increase_quota': hosted_model_providers.openai.paid_increase_quota,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
|
||||
@ -2,11 +2,13 @@ import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from xinference_client.client.restful.restful_client import Client
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.xinference_model import XinferenceModel
|
||||
from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
@ -40,6 +42,8 @@ class XinferenceProvider(BaseModelProvider):
|
||||
model_class = XinferenceModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = XinferenceEmbedding
|
||||
elif model_type == ModelType.RERANKING:
|
||||
model_class = XinferenceReranking
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -113,6 +117,10 @@ class XinferenceProvider(BaseModelProvider):
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
elif model_type == ModelType.RERANKING:
|
||||
rerank_client = Client(credential_kwargs['server_url'])
|
||||
model = rerank_client.get_model(credential_kwargs['model_uid'])
|
||||
model.rerank(query="ping", documents=["ping", "pong"], top_n=2)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
|
||||
@ -14,5 +14,6 @@
|
||||
"xinference",
|
||||
"openllm",
|
||||
"localai",
|
||||
"cohere"
|
||||
"cohere",
|
||||
"jina"
|
||||
]
|
||||
|
||||
@ -23,8 +23,14 @@
|
||||
"currency": "USD"
|
||||
},
|
||||
"claude-2": {
|
||||
"prompt": "11.02",
|
||||
"completion": "32.68",
|
||||
"prompt": "8.00",
|
||||
"completion": "24.00",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"claude-2.1": {
|
||||
"prompt": "8.00",
|
||||
"completion": "24.00",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
}
|
||||
|
||||
@ -21,6 +21,18 @@
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-1106-preview": {
|
||||
"prompt": "0.01",
|
||||
"completion": "0.03",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-vision-preview": {
|
||||
"prompt": "0.01",
|
||||
"completion": "0.03",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-35-turbo": {
|
||||
"prompt": "0.002",
|
||||
"completion": "0.0015",
|
||||
|
||||
10
api/core/model_providers/rules/jina.json
Normal file
10
api/core/model_providers/rules/jina.json
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "fixed",
|
||||
"supported_model_types": [
|
||||
"embeddings"
|
||||
]
|
||||
}
|
||||
@ -6,6 +6,7 @@
|
||||
"model_flexibility": "configurable",
|
||||
"supported_model_types": [
|
||||
"text-generation",
|
||||
"embeddings"
|
||||
"embeddings",
|
||||
"reranking"
|
||||
]
|
||||
}
|
||||
@ -40,7 +40,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
class OrchestratorRuleParser:
|
||||
@ -207,22 +207,22 @@ class OrchestratorRuleParser:
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
return None
|
||||
continue
|
||||
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
return None
|
||||
continue
|
||||
dataset_ids.append(dataset.id)
|
||||
if retrieval_model == 'single':
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
top_k = retrieval_model['top_k']
|
||||
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
top_k = retrieval_model_config['top_k']
|
||||
|
||||
# dynamically adjust top_k when the remaining token number is not enough to support top_k
|
||||
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
|
||||
|
||||
score_threshold = None
|
||||
score_threshold_enable = retrieval_model.get("score_threshold_enable")
|
||||
if score_threshold_enable:
|
||||
score_threshold = retrieval_model.get("score_threshold")
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
@ -239,7 +239,7 @@ class OrchestratorRuleParser:
|
||||
dataset_ids=dataset_ids,
|
||||
tenant_id=kwargs['tenant_id'],
|
||||
top_k=dataset_configs.get('top_k', 2),
|
||||
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
|
||||
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enabled', False) else None,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
|
||||
conversation_message_task=conversation_message_task,
|
||||
return_resource=return_resource,
|
||||
|
||||
69
api/core/third_party/langchain/embeddings/jina_embedding.py
vendored
Normal file
69
api/core/third_party/langchain/embeddings/jina_embedding.py
vendored
Normal file
@ -0,0 +1,69 @@
|
||||
"""Wrapper around Jina embedding models."""
|
||||
from typing import Any, List
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class JinaEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around Jina embedding models.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
api_key: str
|
||||
model: str
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
result = self.invoke_embedding(text=text)
|
||||
embeddings.append(result)
|
||||
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def invoke_embedding(self, text):
|
||||
params = {
|
||||
"model": self.model,
|
||||
"input": [
|
||||
text
|
||||
]
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(
|
||||
'https://api.jina.ai/v1/embeddings',
|
||||
headers=headers,
|
||||
json=params
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"Jina HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
return json_response["data"][0]["embedding"]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Dict
|
||||
|
||||
from httpx import Limits
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.utils import get_from_dict_or_env, check_package_version
|
||||
from pydantic import root_validator
|
||||
|
||||
@ -29,8 +29,7 @@ class AnthropicLLM(ChatAnthropic):
|
||||
base_url=values["anthropic_api_url"],
|
||||
api_key=values["anthropic_api_key"],
|
||||
timeout=values["default_request_timeout"],
|
||||
max_retries=0,
|
||||
connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100),
|
||||
max_retries=0
|
||||
)
|
||||
values["async_client"] = anthropic.AsyncAnthropic(
|
||||
base_url=values["anthropic_api_url"],
|
||||
@ -46,3 +45,16 @@ class AnthropicLLM(ChatAnthropic):
|
||||
"Please it install it with `pip install anthropic`."
|
||||
)
|
||||
return values
|
||||
|
||||
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_text = f"{self.HUMAN_PROMPT} {message.content}"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_text = f"{self.AI_PROMPT} {message.content}"
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_text = f"{message.content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_text
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from typing import Dict, Any, Optional, List, Tuple, Union
|
||||
from typing import Dict, Any, Optional, List, Tuple, Union, cast
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_dict_to_message
|
||||
from langchain.schema import ChatResult, BaseMessage, ChatGeneration
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.schema import ChatResult, BaseMessage, ChatGeneration, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
|
||||
from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
|
||||
|
||||
|
||||
class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
||||
@ -51,13 +53,18 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = self._client_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
||||
params = {**params, **kwargs}
|
||||
if self.streaming:
|
||||
inner_completion = ""
|
||||
@ -65,7 +72,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
||||
params["stream"] = True
|
||||
function_call: Optional[dict] = None
|
||||
for stream_resp in self.completion_with_retry(
|
||||
messages=message_dicts, **params
|
||||
messages=message_dicts, **params
|
||||
):
|
||||
if len(stream_resp["choices"]) > 0:
|
||||
role = stream_resp["choices"][0]["delta"].get("role", role)
|
||||
@ -88,4 +95,47 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
response = self.completion_with_retry(messages=message_dicts, **params)
|
||||
return self._create_chat_result(response)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, LCHumanMessageWithFiles):
|
||||
content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": message.content
|
||||
}
|
||||
]
|
||||
|
||||
for file in message.files:
|
||||
if file.type == PromptMessageFileType.IMAGE:
|
||||
file = cast(ImagePromptMessageFile, file)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": file.data,
|
||||
"detail": file.detail.value
|
||||
}
|
||||
})
|
||||
|
||||
message_dict = {"role": "user", "content": content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
@ -24,7 +24,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +82,8 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
hit_callback.on_tool_end(all_documents)
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
if 'score' in item.metadata and item.metadata['score']:
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
|
||||
@ -192,7 +193,7 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
'search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': self.score_threshold,
|
||||
@ -210,13 +211,13 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': 'hybrid_search',
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model[
|
||||
'score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'score_threshold_enabled'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model[
|
||||
'reranking_model'] if retrieval_model[
|
||||
|
||||
@ -25,7 +25,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
|
||||
@ -106,11 +106,11 @@ class DatasetRetrieverTool(BaseTool):
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'score_threshold_enabled'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents,
|
||||
@ -124,12 +124,12 @@ class DatasetRetrieverTool(BaseTool):
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'score_threshold_enabled'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
@ -148,7 +148,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
documents = hybrid_rerank.rerank(query, documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
self.top_k)
|
||||
else:
|
||||
documents = []
|
||||
@ -158,7 +158,8 @@ class DatasetRetrieverTool(BaseTool):
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
if 'score' in item.metadata and item.metadata['score']:
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in documents]
|
||||
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
|
||||
|
||||
@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = self.col.query(
|
||||
expr=f'metadata["{key}"] == "{value}"',
|
||||
output_fields=["id"]
|
||||
)
|
||||
if result:
|
||||
return [item["id"] for item in result]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_ids_by_doc_ids(self, doc_ids: list):
|
||||
result = self.col.query(
|
||||
expr=f'metadata["doc_id"] in {doc_ids}',
|
||||
|
||||
@ -60,7 +60,7 @@ def _create_weaviate_client(**kwargs: Any) -> Any:
|
||||
|
||||
|
||||
def _default_score_normalizer(val: float) -> float:
|
||||
return 1 - 1 / (1 + np.exp(val))
|
||||
return 1 - val
|
||||
|
||||
|
||||
def _json_serializable(value: Any) -> Any:
|
||||
@ -243,7 +243,8 @@ class Weaviate(VectorStore):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
result = query_obj.with_bm25(query=content).with_limit(k).do()
|
||||
properties = ['text']
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
@ -380,14 +381,14 @@ class Weaviate(VectorStore):
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(k)
|
||||
.with_additional("vector")
|
||||
.with_additional(["vector", "distance"])
|
||||
.do()
|
||||
)
|
||||
else:
|
||||
result = (
|
||||
query_obj.with_near_text(content)
|
||||
.with_limit(k)
|
||||
.with_additional("vector")
|
||||
.with_additional(["vector", "distance"])
|
||||
.do()
|
||||
)
|
||||
|
||||
@ -397,7 +398,7 @@ class Weaviate(VectorStore):
|
||||
docs_and_scores = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
score = np.dot(res["_additional"]["vector"], embedded_query)
|
||||
score = res["_additional"]["distance"]
|
||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||
return docs_and_scores
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from langchain.vectorstores import Weaviate
|
||||
from core.vector_store.vector.weaviate import Weaviate
|
||||
|
||||
|
||||
class WeaviateVectorStore(Weaviate):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user