Compare commits

...

97 Commits

Author SHA1 Message Date
43741ad5d1 feat: bump version to 0.3.34 (#1788) 2023-12-19 14:08:47 +08:00
8dec406161 chore: enchance ext name (#1787) 2023-12-19 14:03:24 +08:00
58f8d74591 fix: unstructured file extension (#1785) 2023-12-19 12:09:48 +08:00
867fc61b12 fix: web app text (#1784) 2023-12-19 11:45:16 +08:00
8e2e477a7f chore: enchance annotation ui (#1781) 2023-12-19 10:25:54 +08:00
9b34f5a9ff feat: unstructured frontend (#1777) 2023-12-18 23:28:25 +08:00
5e34f938c1 Feat/add unstructured support (#1780)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 23:24:06 +08:00
2fd56cb01c Fix/vdb index issue (#1776)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 21:33:54 +08:00
4f0e272549 fix: add then eidt annotion cause show bug (#1775) 2023-12-18 19:33:48 +08:00
1a5279a3ef fix: get billing info in self-hosted edition from current workspace (#1774) 2023-12-18 17:54:16 +08:00
7775f5785f chore: update annotation reply english i18n (#1773) 2023-12-18 17:13:15 +08:00
2de73991ff feat: only tenant owner can subscription. (#1770) 2023-12-18 16:59:31 +08:00
354d033e60 fix: not owner can not pay (#1772) 2023-12-18 16:54:47 +08:00
ebc2cdad2e fix annotation query exception (#1771)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 16:48:34 +08:00
5bb841935e feat: custom webapp logo (#1766) 2023-12-18 16:25:37 +08:00
65fd4b39ce feat: annotation management frontend (#1764) 2023-12-18 15:41:24 +08:00
96d2de2258 fix annotation reply in universal chat (#1768)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 15:04:17 +08:00
a71f2863ac Annotation management (#1767)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 13:10:05 +08:00
a9b942981d fix: issue templates not render correctly (#1763) 2023-12-18 09:22:11 +08:00
4b1ba2ec21 feat: remove billing config. (#1761) 2023-12-17 13:22:45 +08:00
c09184fd94 update bm25 search properties (#1758)
Co-authored-by: Blade <zhangxiaobin@unixyz.cn>
2023-12-15 12:28:03 +08:00
b0d8d196e1 azure openai add gpt-4-1106-preview、gpt-4-vision-preview models (#1751)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2023-12-14 09:55:30 +08:00
7c43123956 feat: can replace logo. (#1752) 2023-12-13 20:21:39 +08:00
eede84eb9e feat: web app support some feature (#1753) 2023-12-13 20:21:11 +08:00
b5b20234e9 feat: update pricing (#1749) 2023-12-13 16:41:40 +08:00
5beb298e47 chore: update term links (#1748) 2023-12-13 15:12:27 +08:00
6b499b9a16 remove stripe and anthropic. (#1746) 2023-12-12 17:59:07 +08:00
4c639961f5 add self checks to issues and discussions templates (#1742) 2023-12-11 23:25:17 +08:00
dfd3f507fb fix: ad block disabled tracking would block ga then can not pay (#1741) 2023-12-11 16:36:58 +08:00
d5695b3170 check rerank document is not empty (#1740)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-11 16:12:11 +08:00
994fceece3 fix: qa regex (#1738) 2023-12-11 15:53:37 +08:00
8c451eb0e6 fix only full text search in app issue (#1736)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-11 15:34:29 +08:00
79b4366203 fix: server component use translate errorts lint error (#1732) 2023-12-11 10:15:49 +08:00
3675d2eae8 fix: prompt null parse var error (#1731) 2023-12-11 10:06:01 +08:00
38b55d2186 fix: default types (#1728) 2023-12-09 23:38:07 +08:00
bee0d12455 fix: remove postgresql default charset config (#1720) 2023-12-08 13:22:04 +08:00
13f2c90a7b bump version to 0.3.33 (#1719) 2023-12-08 13:13:21 +08:00
a3dca3dabc fix: auto generate type error in controllers/web/conversation.py (#1718) 2023-12-08 11:15:07 +08:00
e5c7a81ce3 fix ascii codec error, by using utf8 (#1608)
Co-authored-by: crazywoola <427733928@qq.com>
2023-12-08 09:05:58 +08:00
8b0100523b Feat/regenrate conversation in embeded window (#1708) 2023-12-07 13:17:07 +08:00
1350599c0b fix: process document priority tip (#1712) 2023-12-07 10:38:33 +08:00
bc54cdc537 refactor: typo in dataset docstore (#1711) 2023-12-07 09:24:52 +08:00
5d10cf0fe6 fix: error Class 'builtins.list' is not mapped (#1710) 2023-12-07 09:24:39 +08:00
7b8a10f3ea feat: billing enhancement 20231204 (#1691)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-05 16:53:55 +08:00
cb3a55dae6 feat: remove documents limit (#1697) 2023-12-05 16:53:40 +08:00
5789d76582 fix(web): reserve default copy behavior (#1693) 2023-12-05 16:34:12 +08:00
2e588ae221 feat: use gtag instead gtm (#1695) 2023-12-05 15:05:05 +08:00
b5dd948e56 feat: add upgrade ga test (#1690) 2023-12-04 18:16:59 +08:00
1263b7de75 fix: vector_size convert bytes to MB. (#1684) 2023-12-04 10:51:40 +08:00
75a6122173 feat: SaaS price plan frontend (#1683)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-12-03 22:10:16 +08:00
053102f433 Feat/dify billing (#1679)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: takatost <takatost@users.noreply.github.com>
2023-12-03 20:59:29 +08:00
d3a2c0ed34 fix wrong syntax of type definitions (#1678) 2023-12-03 20:59:13 +08:00
8fbc374f31 chore(deps): bump word-wrap from 1.2.3 to 1.2.5 in /web (#1681)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-12-03 20:50:01 +08:00
08b7ebba91 chore(deps): bump semver from 5.7.1 to 5.7.2 in /web (#1680)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-12-03 20:49:52 +08:00
a1cd043fdc fix: Incorrect order of embedded documents in CacheEmbedding (#1671) 2023-12-03 19:07:00 +08:00
671a8e7972 Doc/use proper links (#1673) 2023-12-02 14:08:10 +08:00
efa16dbb44 feat: drag to upload image (#1666) 2023-12-01 16:50:22 +08:00
a6241be42a fix: Fix typo in documentation: change 'converation' to 'conversation' (#1665) 2023-12-01 13:12:29 +08:00
faa88aafe8 feat: clipboard paste (#1663) 2023-12-01 10:04:14 +08:00
1b3a98425f fix: app setting click pop (#1660) 2023-12-01 09:44:32 +08:00
22bc9ddc73 Hotfix/fix documents index mismatch error in rerank (#1662)
Co-authored-by: baomi.wbm <baomi.wbm@dtwave-inc.com>
2023-11-30 22:03:20 +08:00
0423775687 fix: explore page header menu hide in safari (#1658) 2023-11-30 16:13:42 +08:00
307c170fb6 fix: Jina AI logo (#1656) 2023-11-30 15:41:59 +08:00
0e04fcc071 chore(docker-compose): use proper comment indentation for users to easily toggle on and off features (#1648)
Signed-off-by: Neko Ayaka <neko@ayaka.moe>
2023-11-30 09:46:36 +08:00
4322b17a81 fix: app card hover selector (#1646) 2023-11-29 14:58:27 +08:00
451af66be0 feat: add jina embedding (#1647)
Co-authored-by: takatost <takatost@gmail.com>
2023-11-29 14:58:11 +08:00
454577c6b1 Remove legacy docker startup docs in frontend (#1645) 2023-11-29 13:26:35 +08:00
53be4d2712 fix #1637 (#1638)
Co-authored-by: baomi.wbm <baomi.wbm@dtwave-inc.com>
2023-11-28 20:05:50 +08:00
3c37fd37fa fix: batch mobile layout fixes (#1641) 2023-11-28 20:05:19 +08:00
cf0ba794d7 fix: old webapp url still valid (#1643) 2023-11-28 20:04:46 +08:00
c21e2063fe Update README.md (#1636) 2023-11-28 15:23:05 +08:00
ad037c6615 feat: add items (#1633) 2023-11-27 19:38:00 +08:00
7bbfac5dba Chore: change dataset's i18n to knowledge (#1629) 2023-11-27 17:22:16 +08:00
80ddb00f10 fix: score_threshold_enabled variable (#1627) 2023-11-27 15:38:05 +08:00
74b2260ba6 fix score_threshold_enabled name (#1626)
Co-authored-by: jyong <jyong@dify.ai>
2023-11-27 15:34:45 +08:00
603e55f252 Update README.md (#1623) 2023-11-27 14:20:25 +08:00
a9c1c7d239 feat: fe mobile responsive next (#1609) 2023-11-27 11:47:48 +08:00
3cc697832a feat: bump version to 0.3.32 (#1620) 2023-11-25 16:43:31 +08:00
bb98f5756a feat: add xinference rerank model (#1619) 2023-11-25 16:23:24 +08:00
e1d2203371 fix: provider chatglm tests error (#1618) 2023-11-25 16:04:36 +08:00
93467cb363 fix: dataset tool missing in n-to-1 retrieve mode (#1617) 2023-11-25 16:04:22 +08:00
ea526d0822 feat: chatglm3 support (#1616) 2023-11-25 15:37:07 +08:00
0e627c920f feat: xinference rerank model support (#1615) 2023-11-25 03:56:00 +08:00
ea35f1dce1 feat: bump version to 0.3.31-fix3 (#1606) 2023-11-22 19:40:52 +08:00
a5b80c9d1f Fix/multi thread parameter (#1604) 2023-11-22 18:31:29 +08:00
f704094a5f fix hybrid search when document is none (#1603)
Co-authored-by: jyong <jyong@dify.ai>
2023-11-22 17:53:42 +08:00
1f58f15bff feat: optimize db connections in thread (#1601) 2023-11-22 16:55:59 +08:00
b930716745 fix weaviate hybrid search issue (#1600)
Co-authored-by: jyong <jyong@dify.ai>
2023-11-22 16:41:20 +08:00
9587479b76 fix: chat token spent info style (#1597) 2023-11-22 15:22:50 +08:00
3c0fbf3a6a fix sql transaction error in statistic API (#1586) 2023-11-22 14:28:21 +08:00
caa330c91f feat: bump version to 0.3.31-fix2 (#1592) 2023-11-22 01:53:40 +08:00
4a55d5729d feat: add anthropic claude-2.1 support (#1591) 2023-11-22 01:46:19 +08:00
d6a6697891 fix: safari can not in (#1590) 2023-11-21 20:25:23 +08:00
778cfb37a2 feat: bump version to 0.3.31-fix1 (#1589) 2023-11-21 17:34:41 +08:00
ce85ee3aa6 Update docker-compose.yaml (#1587) 2023-11-21 17:33:35 +08:00
b23de4affc fix: chat on start bug (#1588) 2023-11-21 17:26:49 +08:00
d8a7e894aa fix: retrieval test page hide rerank model also hide retrieval config (#1585) 2023-11-21 16:07:47 +08:00
483 changed files with 14044 additions and 2431 deletions

View File

@ -1,11 +1,18 @@
name: "🕷️ Bug report"
description: Report errors or unexpected behavior [please use English :]
description: Report errors or unexpected behavior
labels:
- bug
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: input
attributes:
label: Dify version
@ -21,8 +28,8 @@ body:
multiple: true
options:
- Cloud
- Self Hosted
- Other (please specify in "Steps to Reproduce")
- Self Hosted (Docker)
- Self Hosted (Source)
validations:
required: true

View File

@ -1,8 +1,16 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation [please use English :]
description: Report issues in our documentation
labels:
- ducumentation
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of requested docs changes

View File

@ -1,8 +1,17 @@
name: "⭐ Feature or enhancement request"
description: Propose something new. [please use English :]
description: Propose something new.
labels:
- enhancement
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Description of the new feature / enhancement

View File

@ -3,6 +3,15 @@ description: "Request help from the community" [please use English :]
labels:
- help-wanted
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of the help you need

View File

@ -3,9 +3,15 @@ description: Report incorrect translations. [please use English :]
labels:
- translation
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: input
attributes:
label: Dify version

View File

@ -34,9 +34,7 @@ jobs:
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
- name: Build and push
uses: docker/build-push-action@v4

View File

@ -34,9 +34,7 @@ jobs:
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
- name: Build and push
uses: docker/build-push-action@v4

View File

@ -20,6 +20,8 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
[v0.3.31:Surpassing the Assistants API Dify's RAG Demonstrates an Impressive 20% Improvement.](https://dify.ai/blog/dify-ai-rag-technology-upgrade-performance-improvement-qa-accuracy)
**Dify** is an LLM application development platform that has already seen over **100,000** applications built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the core tech stack required for building generative AI-native applications, including a built-in RAG engine. With Dify, **you can self-deploy capabilities similar to Assistants API and GPTs based on any LLMs.**
![](./images/demo.png)

View File

@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
@ -119,10 +117,6 @@ HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
ETL_TYPE=dify
UNSTRUCTURED_API_URL=

View File

@ -53,12 +53,3 @@
```
7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
8. Start frontend
You can start the frontend by running `npm install && npm run dev` in web/ folder, or you can use docker to start the frontend, for example:
```
docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5001 --name web-self-hosted langgenius/dify-web:latest
```
This will start a dify frontend, now you are all set, happy coding!

View File

@ -20,7 +20,7 @@ from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension
from extensions.ext_database import db
from extensions.ext_login import login_manager
@ -96,7 +96,6 @@ def initialize_extensions(app):
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
# Flask-Login configuration

View File

@ -28,7 +28,7 @@ from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
import secrets
import base64
@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
pbar.update(len(data_batch))
@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
click.echo(click.style('Start add annotation question value.', fg='green'))
message_annotations = db.session.query(MessageAnnotation).all()
message_annotation_deal_count = 0
if message_annotations:
for message_annotation in message_annotations:
try:
if message_annotation.message_id and not message_annotation.question:
message = db.session.query(Message).filter(
Message.id == message_annotation.message_id
).first()
message_annotation.question = message.query
db.session.add(message_annotation)
db.session.commit()
message_annotation_deal_count += 1
except Exception as e:
click.echo(
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@ -766,3 +790,4 @@ def register_commands(app):
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
app.cli.add_command(add_annotation_question_field_value)

View File

@ -1,11 +1,8 @@
# -*- coding:utf-8 -*-
import os
from datetime import timedelta
import dotenv
from extensions.ext_database import db
from extensions.ext_redis import redis_client
dotenv.load_dotenv()
@ -15,6 +12,7 @@ DEFAULTS = {
'DB_HOST': 'localhost',
'DB_PORT': '5432',
'DB_DATABASE': 'dify',
'DB_CHARSET': '',
'REDIS_HOST': 'localhost',
'REDIS_PORT': '6379',
'REDIS_DB': '0',
@ -43,25 +41,21 @@ DEFAULTS = {
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
'UPLOAD_FILE_BATCH_LIMIT': 5,
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
'INVITE_EXPIRY_HOURS': 72
'INVITE_EXPIRY_HOURS': 72,
'ETL_TYPE': 'dify',
}
@ -91,7 +85,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.31"
self.CURRENT_VERSION = "0.3.34"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@ -149,10 +143,12 @@ class Config:
# ------------------------
db_credentials = {
key: get_env(key) for key in
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE', 'DB_CHARSET']
}
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
db_extras = f"?client_encoding={db_credentials['DB_CHARSET']}" if db_credentials['DB_CHARSET'] else ""
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}"
self.SQLALCHEMY_ENGINE_OPTIONS = {
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
@ -240,7 +236,6 @@ class Config:
self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
# Dataset Configurations.
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
# File upload Configurations.
@ -267,8 +262,6 @@ class Config:
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@ -280,14 +273,13 @@ class Config:
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.ETL_TYPE = get_env('ETL_TYPE')
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
class CloudEditionConfig(Config):
@ -301,6 +293,3 @@ class CloudEditionConfig(Config):
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import extension, setup, version, apikey, admin
# Import app controllers
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
@ -26,5 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio
# Import webhook controllers
from .webhook import stripe
from .billing import billing

View File

@ -0,0 +1,290 @@
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, marshal
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
annotation_hit_history_fields
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from flask import request
class AnnotationReplyActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
args = parser.parse_args()
if action == 'enable':
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == 'disable':
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError('Unsupported annotation reply action')
return result, 200
class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id, annotation_setting_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
keyword = request.args.get('keyword', default=None, type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
'data': marshal(annotation_list, annotation_fields),
'has_more': len(annotation_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {
'data': marshal(annotation_list, annotation_fields)
}
return response, 200
class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {'result': 'success'}, 200
class AnnotationBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
class AnnotationBatchImportStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, annotation_id):
# The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
page, limit)
response = {
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
'has_more': len(annotation_hit_history_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')

View File

@ -12,7 +12,7 @@ from constants.model_template import model_templates, demo_model_templates
from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
@ -57,6 +57,7 @@ class AppListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check('apps')
def post(self):
"""Create app"""
parser = reqparse.RequestParser()

View File

@ -161,7 +161,7 @@ class ChatMessageApi(Resource):
raise InternalServerError()
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400

View File

@ -6,22 +6,23 @@ from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
# get app info
app = _get_app(app_id)
parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('content', type=str, location='json')
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
message_id = str(args['message_id'])
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app.id
).first()
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
if annotation:
annotation.content = args['content']
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args['content'],
account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
return {'result': 'success'}
return annotation
class MessageAnnotationCountApi(Resource):
@ -249,7 +230,7 @@ class MessageMoreLikeThisApi(Resource):
raise InternalServerError()
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
"""Modify app model config"""
app_id = str(app_id)
app_model = _get_app(app_id)
app = _get_app(app_id)
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app_model.mode
mode=app.mode
)
new_app_model_config = AppModelConfig(
app_id=app_model.id,
app_id=app.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config)
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app.app_model_config_id = new_app_model_config.id
db.session.commit()
app_model_config_was_updated.send(
app_model,
app,
app_model_config=new_app_model_config
)

View File

@ -62,16 +62,15 @@ class DailyConversationStatistic(Resource):
sql_query += ' GROUP BY date order by date'
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
response_data = []
for i in rs:
response_data.append({
'date': str(i.date),
'conversation_count': i.conversation_count
})
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'conversation_count': i.conversation_count
})
return jsonify({
'data': response_data
@ -124,16 +123,15 @@ class DailyTerminalsStatistic(Resource):
sql_query += ' GROUP BY date order by date'
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
response_data = []
for i in rs:
response_data.append({
'date': str(i.date),
'terminal_count': i.terminal_count
})
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'terminal_count': i.terminal_count
})
return jsonify({
'data': response_data
@ -187,18 +185,17 @@ class DailyTokenCostStatistic(Resource):
sql_query += ' GROUP BY date order by date'
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
response_data = []
for i in rs:
response_data.append({
'date': str(i.date),
'token_count': i.token_count,
'total_price': i.total_price,
'currency': 'USD'
})
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'token_count': i.token_count,
'total_price': i.total_price,
'currency': 'USD'
})
return jsonify({
'data': response_data
@ -256,16 +253,15 @@ LEFT JOIN conversations c on c.id=subquery.conversation_id
GROUP BY date
ORDER BY date"""
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
response_data = []
for i in rs:
response_data.append({
'date': str(i.date),
'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
for i in rs:
response_data.append({
'date': str(i.date),
'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
return jsonify({
'data': response_data
@ -320,20 +316,19 @@ class UserSatisfactionRateStatistic(Resource):
sql_query += ' GROUP BY date order by date'
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
response_data = []
for i in rs:
response_data.append({
'date': str(i.date),
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
})
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
})
return jsonify({
'data': response_data
})
'data': response_data
})
class AverageResponseTimeStatistic(Resource):
@ -383,16 +378,15 @@ class AverageResponseTimeStatistic(Resource):
sql_query += ' GROUP BY date order by date'
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
response_data = []
for i in rs:
response_data.append({
'date': str(i.date),
'latency': round(i.latency * 1000, 4)
})
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'latency': round(i.latency * 1000, 4)
})
return jsonify({
'data': response_data
@ -447,16 +441,15 @@ WHERE app_id = :app_id'''
sql_query += ' GROUP BY date order by date'
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
response_data = []
for i in rs:
response_data.append({
'date': str(i.date),
'tps': round(i.tokens_per_second, 4)
})
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'tps': round(i.tokens_per_second, 4)
})
return jsonify({
'data': response_data

View File

@ -0,0 +1,61 @@
from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import only_edition_cloud
from libs.login import login_required
from services.billing_service import BillingService
class BillingInfo(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
edition = current_app.config['EDITION']
if edition != 'CLOUD':
return {"enabled": False}
return BillingService.get_info(current_user.current_tenant_id)
class Subscription(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()
BillingService.is_tenant_owner(current_user)
return BillingService.get_subscription(args['plan'],
args['interval'],
current_user.email,
current_user.current_tenant_id)
class Invoices(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
BillingService.is_tenant_owner(current_user)
return BillingService.get_invoices(current_user.email)
api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')

View File

@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')

View File

@ -16,7 +16,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \
InvalidMetadataError, ArchivedDocumentImmutableError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.indexing_runner import IndexingRunner
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
@ -194,6 +194,7 @@ class DatasetDocumentListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(documents_and_batch_fields)
@cloud_edition_billing_resource_check('vector_space')
def post(self, dataset_id):
dataset_id = str(dataset_id)
@ -252,6 +253,7 @@ class DatasetInitApi(Resource):
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check('vector_space')
def post(self):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
@ -693,6 +695,7 @@ class DocumentStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
@ -770,14 +773,6 @@ class DocumentStatusApi(DocumentResource):
if not document.archived:
raise InvalidActionError('Document is not archived.')
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + 1
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
document.archived = False
document.archived_at = None
document.archived_by = None
@ -856,21 +851,6 @@ class DocumentRecoverApi(DocumentResource):
return {'result': 'success'}, 204
class DocumentLimitApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""get document limit"""
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
return {
'documents_count': documents_count,
'documents_limit': tenant_document_count
}, 200
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@ -896,4 +876,3 @@ api.add_resource(DocumentStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentLimitApi, '/datasets/limit')

View File

@ -11,7 +11,7 @@ from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from libs.login import login_required
@ -114,6 +114,7 @@ class DatasetDocumentSegmentApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
def patch(self, dataset_id, segment_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -200,6 +201,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
@ -250,6 +252,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -344,6 +347,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)

View File

@ -69,5 +69,20 @@ class FilePreviewApi(Resource):
return {'content': text}
class FileeSupportTypApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
etl_type = current_app.config['ETL_TYPE']
if etl_type == 'Unstructured':
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
else:
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
return {'allowed_extensions': allowed_extensions}
api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileeSupportTypApi, '/files/support-type')

View File

@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource):
return {'result': 'success'}, 200
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@ -73,7 +73,7 @@ class ConversationRenameApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:

View File

@ -14,6 +14,7 @@ from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from controllers.console.wraps import cloud_edition_billing_resource_check
class InstalledAppsListApi(Resource):
@ -47,6 +48,7 @@ class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('apps')
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')

View File

@ -105,7 +105,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise InternalServerError()
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@ -104,7 +104,7 @@ class UniversalChatStopApi(UniversalChatResource):
return {'result': 'success'}, 200
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@ -66,7 +66,7 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:

View File

@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw
}
@marshal_with(parameters_fields)
@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
}

View File

@ -1,61 +0,0 @@
import logging
import stripe
from flask import request, current_app
from flask_restful import Resource
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService
class StripeWebhookApi(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400
# Handle the checkout.session.completed event
if event['type'] == 'checkout.session.completed':
logging.debug(event['data']['object']['id'])
logging.debug(event['data']['object']['amount_subtotal'])
logging.debug(event['data']['object']['currency'])
logging.debug(event['data']['object']['payment_intent'])
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])
session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)
logging.debug(session.line_items['data'][0]['quantity'])
# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()
try:
provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:
logging.debug(str(e))
return 'success', 200
return 'success', 200
api.add_resource(StripeWebhookApi, '/webhook/stripe')

View File

@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse, marshal_with, abort, fields, marsh
import services
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from libs.helper import TimestampField
from extensions.ext_database import db
from models.account import Account, TenantAccountJoin
@ -47,6 +47,7 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('members')
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('emails', type=str, required=True, location='json', action='append')

View File

@ -9,8 +9,8 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_checkout_service import ProviderCheckoutService
from services.provider_service import ProviderService
from services.billing_service import BillingService
class ModelProviderListApi(Resource):
@ -115,7 +115,7 @@ class ModelProviderModelValidateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
@ -155,7 +155,7 @@ class ModelProviderModelUpdateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
@ -184,7 +184,7 @@ class ModelProviderModelUpdateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
args = parser.parse_args()
provider_service = ProviderService()
@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
provider_service = ProviderCheckoutService()
provider_checkout = provider_service.create_checkout(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
account=current_user
)
if provider_name != 'anthropic':
raise ValueError(f'provider name {provider_name} is invalid')
return {
'url': provider_checkout.get_checkout_url()
}
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id)
return data
class ModelProviderFreeQuotaSubmitApi(Resource):

View File

@ -10,12 +10,15 @@ from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.setup import setup_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, UnsupportedFileTypeError
from libs.helper import TimestampField
from extensions.ext_database import db
from models.account import Tenant
import services
from services.account_service import TenantService
from services.workspace_service import WorkspaceService
from services.file_service import FileService
provider_fields = {
'provider_name': fields.String,
@ -34,6 +37,7 @@ tenant_fields = {
'providers': fields.List(fields.Nested(provider_fields)),
'in_trial': fields.Boolean,
'trial_end_reason': fields.String,
'custom_config': fields.Raw(attribute='custom_config'),
}
tenants_fields = {
@ -130,6 +134,61 @@ class SwitchWorkspaceApi(Resource):
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
class CustomConfigWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('remove_webapp_brand', type=bool, location='json')
parser.add_argument('replace_webapp_logo', type=str, location='json')
args = parser.parse_args()
custom_config_dict = {
'remove_webapp_brand': args['remove_webapp_brand'],
'replace_webapp_logo': args['replace_webapp_logo'],
}
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant.custom_config_dict = custom_config_dict
db.session.commit()
return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
class WebappLogoWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
def post(self):
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
extension = file.filename.split('.')[-1]
if extension.lower() not in ['svg', 'png']:
raise UnsupportedFileTypeError()
try:
upload_file = FileService.upload_file(file, current_user, True)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return { 'id': upload_file.id }, 201
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
@ -137,3 +196,5 @@ api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all ten
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')

View File

@ -5,6 +5,7 @@ from flask import current_app, abort
from flask_login import current_user
from controllers.console.workspace.error import AccountNotInitializedError
from services.billing_service import BillingService
def account_initialization_required(view):
@ -41,3 +42,35 @@ def only_edition_self_hosted(view):
return view(*args, **kwargs)
return decorated
def cloud_edition_billing_resource_check(resource: str,
error_msg: str = "You have reached the limit of your subscription."):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if current_app.config['EDITION'] == 'CLOUD':
tenant_id = current_user.current_tenant_id
billing_info = BillingService.get_info(tenant_id)
members = billing_info['members']
apps = billing_info['apps']
vector_space = billing_info['vector_space']
annotation_quota_limit = billing_info['annotation_quota_limit']
if resource == 'members' and 0 < members['limit'] <= members['size']:
abort(403, error_msg)
elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
abort(403, error_msg)
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
abort(403, error_msg)
elif resource == 'workspace_custom' and not billing_info['can_replace_logo']:
abort(403, error_msg)
elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] < annotation_quota_limit['size']:
abort(403, error_msg)
else:
return view(*args, **kwargs)
return view(*args, **kwargs)
return decorated
return interceptor

View File

@ -1,10 +1,12 @@
from flask import request, Response
from flask_restful import Resource
from werkzeug.exceptions import NotFound
import services
from controllers.files import api
from libs.exception import BaseHTTPException
from services.file_service import FileService
from services.account_service import TenantService
class ImagePreviewApi(Resource):
@ -29,9 +31,30 @@ class ImagePreviewApi(Resource):
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
class WorkspaceWebappLogoApi(Resource):
def get(self, workspace_id):
workspace_id = str(workspace_id)
custom_config = TenantService.get_custom_config(workspace_id)
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
if not webapp_logo_file_id:
raise NotFound(f'webapp logo is not found')
try:
generator, mimetype = FileService.get_public_image_preview(
webapp_logo_file_id,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')
class UnsupportedFileTypeError(BaseHTTPException):

View File

@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@ -98,7 +98,7 @@ class ChatApi(AppApiResource):
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
args = parser.parse_args()
@ -150,7 +150,7 @@ class ChatStopApi(AppApiResource):
return {'result': 'success'}, 200
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@ -11,7 +11,7 @@ from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from libs.login import current_user
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
@ -24,6 +24,7 @@ from services.file_service import FileService
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by text."""
parser = reqparse.RequestParser()
@ -88,6 +89,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
parser = reqparse.RequestParser()
@ -147,6 +149,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}
@ -212,6 +215,7 @@ class DocumentAddByFileApi(DatasetApiResource):
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
args = {}

View File

@ -3,7 +3,7 @@ from flask_restful import reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
@ -14,6 +14,8 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
def post(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
@ -144,6 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
@cloud_edition_billing_resource_check('vector_space', 'dataset')
def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)

View File

@ -11,6 +11,7 @@ from libs.login import _get_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.model import ApiToken, App
from services.billing_service import BillingService
def validate_app_token(view=None):
@ -40,6 +41,33 @@ def validate_app_token(view=None):
return decorator
def cloud_edition_billing_resource_check(resource: str,
api_token_type: str,
error_msg: str = "You have reached the limit of your subscription."):
def interceptor(view):
def decorated(*args, **kwargs):
if current_app.config['EDITION'] == 'CLOUD':
api_token = validate_and_get_api_token(api_token_type)
billing_info = BillingService.get_info(api_token.tenant_id)
members = billing_info['members']
apps = billing_info['apps']
vector_space = billing_info['vector_space']
if resource == 'members' and 0 < members['limit'] <= members['size']:
raise Unauthorized(error_msg)
elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
raise Unauthorized(error_msg)
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
raise Unauthorized(error_msg)
else:
return view(*args, **kwargs)
return view(*args, **kwargs)
return decorated
return interceptor
def validate_dataset_token(view=None):
def decorator(view):
@wraps(view)

View File

@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@ -68,7 +68,7 @@ class ConversationRenameApi(WebApiResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:

View File

@ -139,7 +139,7 @@ class MessageMoreLikeThisApi(WebApiResource):
raise InternalServerError()
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@ -1,11 +1,15 @@
# -*- coding:utf-8 -*-
import os
from flask_restful import fields, marshal_with
from flask import current_app
from werkzeug.exceptions import Forbidden
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.model import Site
from services.billing_service import BillingService
class AppSiteApi(WebApiResource):
@ -39,6 +43,8 @@ class AppSiteApi(WebApiResource):
'site': fields.Nested(site_fields),
'model_config': fields.Nested(model_config_fields, allow_null=True),
'plan': fields.String,
'can_replace_logo': fields.Boolean,
'custom_config': fields.Raw(attribute='custom_config'),
}
@marshal_with(app_fields)
@ -50,7 +56,14 @@ class AppSiteApi(WebApiResource):
if not site:
raise Forbidden()
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id)
edition = os.environ.get('EDITION')
can_replace_logo = False
if edition == 'CLOUD':
info = BillingService.get_info(app_model.tenant_id)
can_replace_logo = info['can_replace_logo']
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
api.add_resource(AppSiteApi, '/site')
@ -59,7 +72,7 @@ api.add_resource(AppSiteApi, '/site')
class AppSiteInfo:
"""Class to store site information."""
def __init__(self, tenant, app, site, end_user):
def __init__(self, tenant, app, site, end_user, can_replace_logo):
"""Initialize AppSiteInfo instance."""
self.app_id = app.id
self.end_user_id = end_user
@ -67,6 +80,16 @@ class AppSiteInfo:
self.site = site
self.model_config = None
self.plan = tenant.plan
self.can_replace_logo = can_replace_logo
if can_replace_logo:
base_url = current_app.config.get('FILES_URL')
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
self.custom_config = {
'remove_webapp_brand': remove_webapp_brand,
'replace_webapp_logo': replace_webapp_logo,
}
if app.enable_site and site.prompt_public:
app_model_config = app.app_model_config

View File

@ -40,7 +40,7 @@ def decode_jwt_token():
site = db.session.query(Site).filter(Site.code == app_code).first()
if not app_model:
raise NotFound()
if not app_code and not site:
if not app_code or not site:
raise Unauthorized('Site URL is no longer valid.')
if app_model.enable_site is False:
raise Unauthorized('Site is disabled.')

View File

@ -59,7 +59,7 @@ class AgentExecutor:
self.configuration = configuration
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,

View File

@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.embedding.cached_embedding import CacheEmbedding
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.dataset import Dataset
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
class Completion:
@ -33,7 +38,7 @@ class Completion:
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
auto_generate_name: bool = True):
auto_generate_name: bool = True, from_source: str = 'console'):
"""
errors: ProviderTokenNotInitError
"""
@ -109,7 +114,10 @@ class Completion:
fake_response=str(e)
)
return
# check annotation reply
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
if annotation_reply:
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
@ -166,17 +174,18 @@ class Completion:
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation = ModerationFactory(type, app_id, tenant_id,
app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
@ -324,6 +333,81 @@ class Completion:
external_context = memory.load_memory_variables({})
return external_context[memory_key]
@classmethod
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
from_source: str) -> bool:
"""Get memory messages."""
app_model_config = conversation_message_task.app_model_config
app = conversation_message_task.app
annotation_reply = app_model_config.annotation_reply_dict
if annotation_reply['enabled']:
try:
score_threshold = annotation_reply.get('score_threshold', 1)
embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
tenant_id=app.tenant_id,
model_provider_name=embedding_provider_name,
model_name=embedding_model_name
)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
'annotation'
)
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
documents = vector_index.search(
conversation_message_task.query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
annotation_id = documents[0].metadata['annotation_id']
score = documents[0].metadata['score']
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
# insert annotation history
AppAnnotationService.add_annotation_history(annotation.id,
app.id,
annotation.question,
annotation.content,
conversation_message_task.query,
conversation_message_task.user.id,
conversation_message_task.message.id,
from_source,
score)
return True
except Exception as e:
logging.warning(f'Query annotation failed, exception: {str(e)}.')
return False
return False
@classmethod
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
conversation: Conversation,

View File

@ -319,9 +319,13 @@ class ConversationMessageTask:
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end()
def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
self._pub_handler.pub_end()
class PubHandler:
def __init__(self, user: Union[Account | EndUser], task_id: str,
def __init__(self, user: Union[Account, EndUser], task_id: str,
message: Message, conversation: Conversation,
chain_pub: bool = False, agent_thought_pub: bool = False):
self._channel = PubHandler.generate_channel_name(user, task_id)
@ -334,7 +338,7 @@ class PubHandler:
self._agent_thought_pub = agent_thought_pub
@classmethod
def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str):
def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
if not user:
raise ValueError("user is required")
@ -342,7 +346,7 @@ class PubHandler:
return "generate_result:{}-{}".format(user_str, task_id)
@classmethod
def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result_stopped:{}-{}".format(user_str, task_id)
@ -435,7 +439,7 @@ class PubHandler:
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
'conversation_id': self._conversation.id,
}
}
if retriever_resource:
@ -446,6 +450,30 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()
def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
content = {
'event': 'annotation',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id,
'text': text,
'annotation_id': annotation_id,
'annotation_author_name': annotation_author_name
}
}
self._message.answer = text
self._message.provider_response_latency = time.perf_counter() - start_at
db.session.commit()
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_end(self):
content = {
'event': 'end',
@ -454,7 +482,7 @@ class PubHandler:
redis_client.publish(self._channel, json.dumps(content))
@classmethod
def pub_error(cls, user: Union[Account | EndUser], task_id: str, e):
def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
content = {
'error': type(e).__name__,
'description': e.description if getattr(e, 'description', None) is not None else str(e)
@ -467,7 +495,7 @@ class PubHandler:
return redis_client.get(self._stopped_cache_key) is not None
@classmethod
def ping(cls, user: Union[Account | EndUser], task_id: str):
def ping(cls, user: Union[Account, EndUser], task_id: str):
content = {
'event': 'ping'
}
@ -476,7 +504,7 @@ class PubHandler:
redis_client.publish(channel, json.dumps(content))
@classmethod
def stop(cls, user: Union[Account | EndUser], task_id: str):
def stop(cls, user: Union[Account, EndUser], task_id: str):
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
redis_client.setex(stopped_cache_key, 600, 1)

View File

@ -3,7 +3,8 @@ from pathlib import Path
from typing import List, Union, Optional
import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from flask import current_app
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
@ -11,6 +12,13 @@ from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile
@ -49,14 +57,34 @@ class FileExtractor:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
if is_automatic:
loader = UnstructuredFileLoader(
file_path, strategy="hi_res", mode="elements"
)
# loader = UnstructuredAPIFileLoader(
# file_path=filenames[0],
# api_key="FAKE_API_KEY",
# )
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension == '.docx':
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)

View File

@ -0,0 +1,41 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredEmailLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,48 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredMarkdownLoader(BaseLoader):
"""Load md files.
Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
remove_images: Whether to remove images from the text.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.md import partition_md
elements = partition_md(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredMsgLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.msg import partition_msg
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredPPTLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.ppt import partition_ppt
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredPPTXLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredTextLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.text import partition_text
elements = partition_text(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredXmlLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.xml import partition_xml
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@ -8,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore:
class DatasetDocumentStore:
def __init__(
self,
dataset: Dataset,
@ -20,7 +20,7 @@ class DatesetDocumentStore:
self._document_id = document_id
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
return cls(**config_dict)
def to_dict(self) -> Dict[str, Any]:

View File

@ -18,31 +18,30 @@ class CacheEmbedding(Embeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
# use doc embedding cache or store if not exists
text_embeddings = []
embedding_queue_texts = []
for text in texts:
text_embeddings = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
text_embeddings[i] = embedding.get_embedding()
else:
embedding_queue_texts.append(text)
embedding_queue_indices.append(i)
if embedding_queue_texts:
if embedding_queue_indices:
try:
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
i = 0
normalized_embedding_results = []
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
for i, indice in enumerate(embedding_queue_indices):
hash = helper.generate_text_hash(texts[indice])
try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
normalized_embedding_results.append(normalized_embedding)
text_embeddings[indice] = normalized_embedding
embedding.set_embedding(normalized_embedding)
db.session.add(embedding)
db.session.commit()
@ -52,10 +51,7 @@ class CacheEmbedding(Embeddings):
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
text_embeddings.extend(normalized_embedding_results)
return text_embeddings
def embed_query(self, text: str) -> List[float]:

View File

@ -10,7 +10,7 @@ from flask import current_app
from extensions.ext_storage import storage
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
class UploadFileParser:

View File

@ -32,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError

View File

@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):
self._save_dataset_keyword_table(keyword_table)
def delete_by_metadata_field(self, key: str, value: str):
pass
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)

View File

@ -100,7 +100,6 @@ class MilvusVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
@ -121,6 +120,16 @@ class MilvusVectorIndex(BaseVectorIndex):
'filter': f'id in {ids}'
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_metadata_field(key, value)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()

View File

@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
))
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()

View File

@ -9,12 +9,17 @@ from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings)
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
self._attributes = attributes
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
@ -33,7 +38,8 @@ class VectorIndex:
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings
embeddings=embeddings,
attributes=attributes
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

View File

@ -27,9 +27,10 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
@ -111,7 +112,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
attributes = self._attributes
if self._is_origin():
attributes = ['doc_id']
@ -141,6 +142,27 @@ class WeaviateVectorIndex(BaseVectorIndex):
"valueText": document_id
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": [key],
"valueText": value
})
def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

View File

@ -15,7 +15,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_providers.error import ProviderTokenNotInitError
@ -106,7 +106,8 @@ class IndexingRunner:
document_id=dataset_document.id
).all()
db.session.delete(document_segments)
for document_segment in document_segments:
db.session.delete(document_segment)
db.session.commit()
# load file
@ -396,7 +397,7 @@ class IndexingRunner:
one_or_none()
if file_detail:
text_docs = FileExtractor.load(file_detail, is_automatic=False)
text_docs = FileExtractor.load(file_detail, is_automatic=True)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
@ -474,7 +475,7 @@ class IndexingRunner:
)
# save node to document segment
doc_store = DatesetDocumentStore(
doc_store = DatasetDocumentStore(
dataset=dataset,
user_id=dataset_document.created_by,
document_id=dataset_document.id
@ -631,8 +632,8 @@ class IndexingRunner:
return text
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)
return [
{

View File

@ -75,6 +75,9 @@ class ModelProviderFactory:
elif provider_name == 'cohere':
from core.model_providers.providers.cohere_provider import CohereProvider
return CohereProvider
elif provider_name == 'jina':
from core.model_providers.providers.jina_provider import JinaProvider
return JinaProvider
else:
raise NotImplementedError

View File

@ -0,0 +1,25 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
class JinaEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = JinaEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Jina: {str(ex)}")
else:
return ex

View File

@ -23,7 +23,8 @@ FUNCTION_CALL_MODELS = [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k'
'gpt-35-turbo-16k',
'gpt-4-1106-preview'
]
class AzureOpenAIModel(BaseLLM):

View File

@ -1,27 +1,45 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.llms import ChatGLM
from langchain.schema import LLMResult
from langchain.schema import LLMResult, get_buffer_string
from core.model_providers.error import LLMBadRequestError
from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
class ChatGLMModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatGLM(
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
if provider_model_kwargs.get('max_length') is not None:
extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
endpoint_url=self.credentials.get('api_base'),
**provider_model_kwargs
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['api_base'] + '/v1'
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to ChatGLM API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to ChatGLM API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("ChatGLM service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@ -1,14 +1,15 @@
import logging
from typing import Optional, List
from typing import List, Optional
import cohere
import openai
from langchain.schema import Document
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.error import (LLMAPIConnectionError,
LLMAPIUnavailableError,
LLMAuthorizationError,
LLMBadRequestError, LLMRateLimitError)
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider
from langchain.schema import Document
class CohereReranking(BaseReranking):
@ -23,13 +24,20 @@ class CohereReranking(BaseReranking):
super().__init__(model_provider, client, name)
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> \
Optional[List[Document]]:
if not documents:
return []
docs = []
doc_id = []
unique_documents = []
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id'])
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
rerank_documents = []

View File

@ -0,0 +1,62 @@
import logging
from typing import List, Optional
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider
from langchain.schema import Document
from xinference_client.client.restful.restful_client import Client
class XinferenceReranking(BaseReranking):
def __init__(self, model_provider: BaseModelProvider, name: str):
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = Client(self.credentials['server_url'])
super().__init__(model_provider, client, name)
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
if not documents:
return []
docs = []
doc_id = []
unique_documents = []
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id'])
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
model = self.client.get_model(self.credentials['model_uid'])
response = model.rerank(query=query, documents=docs, top_n=top_k)
rerank_documents = []
for idx, result in enumerate(response['results']):
# format document
index = result['index']
rerank_document = Document(
page_content=result['document'],
metadata={
"doc_id": documents[index].metadata['doc_id'],
"doc_hash": documents[index].metadata['doc_hash'],
"document_id": documents[index].metadata['document_id'],
"dataset_id": documents[index].metadata['dataset_id'],
'score': result['relevance_score']
}
)
# score threshold check
if score_threshold is not None:
if result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return rerank_documents
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference rerank: {str(ex)}")

View File

@ -32,9 +32,12 @@ class AnthropicProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'id': 'claude-2.1',
'name': 'claude-2.1',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'claude-2',
@ -44,6 +47,11 @@ class AnthropicProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
]
else:
return []
@ -73,12 +81,18 @@ class AnthropicProvider(BaseModelProvider):
:param model_type:
:return:
"""
model_max_tokens = {
'claude-instant-1': 100000,
'claude-2': 100000,
'claude-2.1': 200000,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=model_max_tokens.get(model_name, 100000), default=256, precision=0),
)
@classmethod
@ -177,23 +191,6 @@ class AnthropicProvider(BaseModelProvider):
return False
def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.
:return:
"""
if hosted_model_providers.anthropic \
and hosted_model_providers.anthropic.paid_enabled:
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""

View File

@ -122,6 +122,22 @@ class AzureOpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-1106-preview',
'name': 'gpt-4-1106-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-vision-preview',
'name': 'gpt-4-vision-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.VISION.value
]
},
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
@ -171,6 +187,8 @@ class AzureOpenAIProvider(BaseModelProvider):
base_model_max_tokens = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-4-1106-preview': 4096,
'gpt-4-vision-preview': 4096,
'gpt-35-turbo': 4096,
'gpt-35-turbo-16k': 16384,
'text-davinci-003': 4097,
@ -376,6 +394,18 @@ class AzureOpenAIProvider(BaseModelProvider):
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-4-1106-preview',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-4-vision-preview',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='text-davinci-003',
model_type=ModelType.TEXT_GENERATION,

View File

@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC):
).update({'last_used': datetime.utcnow()})
db.session.commit()
def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.
:return:
"""
return None
def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
"""
get provider model.

View File

@ -2,6 +2,7 @@ import json
from json import JSONDecodeError
from typing import Type
import requests
from langchain.llms import ChatGLM
from core.helper import encrypter
@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
'id': 'chatglm3-6b',
'name': 'ChatGLM3-6B',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
'id': 'chatglm3-6b-32k',
'name': 'ChatGLM3-6B-32K',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'chatglm-6b': 2000,
'chatglm2-6b': 32000,
'chatglm3-6b-32k': 32000,
'chatglm3-6b': 8000,
'chatglm2-6b': 8000,
}
max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
)
@classmethod
@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
try:
credential_kwargs = {
'endpoint_url': credentials['api_base']
}
response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
llm = ChatGLM(
max_token=10,
**credential_kwargs
)
llm("ping")
if response.status_code != 200:
raise Exception('ChatGLM Endpoint URL is invalid.')
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel):
quota_limit: int = 0
"""Quota limit for the openai hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
class HostedAzureOpenAI(BaseModel):
@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel):
quota_limit: int = 0
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1000000
paid_min_quantity: int = 20
paid_max_quantity: int = 100
class HostedModelProviders(BaseModel):
@ -68,8 +62,6 @@ def init_app(app: Flask):
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
)
if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
@ -85,10 +77,6 @@ def init_app(app: Flask):
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):

View File

@ -0,0 +1,141 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
from models.provider import ProviderType
class JinaProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'jina'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'jina-embeddings-v2-base-en',
'name': 'jina-embeddings-v2-base-en',
},
{
'id': 'jina-embeddings-v2-small-en',
'name': 'jina-embeddings-v2-small-en',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.EMBEDDINGS:
model_class = JinaEmbedding
else:
raise NotImplementedError
return model_class
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Jina API Key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
}
embedding = JinaEmbeddings(
model='jina-embeddings-v2-small-en',
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
def _get_text_generation_model_mode(self, model_name) -> str:
raise NotImplementedError
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
raise NotImplementedError

View File

@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider):
return False
def get_payment_info(self) -> Optional[dict]:
"""
get payment info if it payable.
:return:
"""
if hosted_model_providers.openai \
and hosted_model_providers.openai.paid_enabled:
return {
'product_id': hosted_model_providers.openai.paid_stripe_price_id,
'increase_quota': hosted_model_providers.openai.paid_increase_quota,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""

View File

@ -2,11 +2,13 @@ import json
from typing import Type
import requests
from xinference_client.client.restful.restful_client import Client
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
@ -40,6 +42,8 @@ class XinferenceProvider(BaseModelProvider):
model_class = XinferenceModel
elif model_type == ModelType.EMBEDDINGS:
model_class = XinferenceEmbedding
elif model_type == ModelType.RERANKING:
model_class = XinferenceReranking
else:
raise NotImplementedError
@ -113,6 +117,10 @@ class XinferenceProvider(BaseModelProvider):
)
embedding.embed_query("ping")
elif model_type == ModelType.RERANKING:
rerank_client = Client(credential_kwargs['server_url'])
model = rerank_client.get_model(credential_kwargs['model_uid'])
model.rerank(query="ping", documents=["ping", "pong"], top_n=2)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@ -14,5 +14,6 @@
"xinference",
"openllm",
"localai",
"cohere"
"cohere",
"jina"
]

View File

@ -23,8 +23,14 @@
"currency": "USD"
},
"claude-2": {
"prompt": "11.02",
"completion": "32.68",
"prompt": "8.00",
"completion": "24.00",
"unit": "0.000001",
"currency": "USD"
},
"claude-2.1": {
"prompt": "8.00",
"completion": "24.00",
"unit": "0.000001",
"currency": "USD"
}

View File

@ -21,6 +21,18 @@
"unit": "0.001",
"currency": "USD"
},
"gpt-4-1106-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-vision-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo": {
"prompt": "0.002",
"completion": "0.0015",

View File

@ -0,0 +1,10 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed",
"supported_model_types": [
"embeddings"
]
}

View File

@ -6,6 +6,7 @@
"model_flexibility": "configurable",
"supported_model_types": [
"text-generation",
"embeddings"
"embeddings",
"reranking"
]
}

View File

@ -40,7 +40,7 @@ default_retrieval_model = {
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
'score_threshold_enabled': False
}
class OrchestratorRuleParser:
@ -207,22 +207,22 @@ class OrchestratorRuleParser:
).first()
if not dataset:
return None
continue
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None
continue
dataset_ids.append(dataset.id)
if retrieval_model == 'single':
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
top_k = retrieval_model['top_k']
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
top_k = retrieval_model_config['top_k']
# dynamically adjust top_k when the remaining token number is not enough to support top_k
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None
score_threshold_enable = retrieval_model.get("score_threshold_enable")
if score_threshold_enable:
score_threshold = retrieval_model.get("score_threshold")
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
@ -239,7 +239,7 @@ class OrchestratorRuleParser:
dataset_ids=dataset_ids,
tenant_id=kwargs['tenant_id'],
top_k=dataset_configs.get('top_k', 2),
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enabled', False) else None,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,

View File

@ -0,0 +1,69 @@
"""Wrapper around Jina embedding models."""
from typing import Any, List
import requests
from pydantic import BaseModel, Extra
from langchain.embeddings.base import Embeddings
class JinaEmbeddings(BaseModel, Embeddings):
"""Wrapper around Jina embedding models.
"""
client: Any #: :meta private:
api_key: str
model: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Jina's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
result = self.invoke_embedding(text=text)
embeddings.append(result)
return [list(map(float, e)) for e in embeddings]
def invoke_embedding(self, text):
params = {
"model": self.model,
"input": [
text
]
}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
response = requests.post(
'https://api.jina.ai/v1/embeddings',
headers=headers,
json=params
)
if not response.ok:
raise ValueError(f"Jina HTTP {response.status_code} error: {response.text}")
json_response = response.json()
return json_response["data"][0]["embedding"]
def embed_query(self, text: str) -> List[float]:
"""Call out to Jina's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]

View File

@ -1,7 +1,7 @@
from typing import Dict
from httpx import Limits
from langchain.chat_models import ChatAnthropic
from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.utils import get_from_dict_or_env, check_package_version
from pydantic import root_validator
@ -29,8 +29,7 @@ class AnthropicLLM(ChatAnthropic):
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"],
timeout=values["default_request_timeout"],
max_retries=0,
connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100),
max_retries=0
)
values["async_client"] = anthropic.AsyncAnthropic(
base_url=values["anthropic_api_url"],
@ -46,3 +45,16 @@ class AnthropicLLM(ChatAnthropic):
"Please it install it with `pip install anthropic`."
)
return values
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"{self.HUMAN_PROMPT} {message.content}"
elif isinstance(message, AIMessage):
message_text = f"{self.AI_PROMPT} {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"{message.content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text

View File

@ -1,11 +1,13 @@
from typing import Dict, Any, Optional, List, Tuple, Union
from typing import Dict, Any, Optional, List, Tuple, Union, cast
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import ChatResult, BaseMessage, ChatGeneration
from pydantic import root_validator
from langchain.schema import ChatResult, BaseMessage, ChatGeneration, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
class EnhanceAzureChatOpenAI(AzureChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@ -51,13 +53,18 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [self._convert_message_to_dict(m) for m in messages]
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
@ -65,7 +72,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
params["stream"] = True
function_call: Optional[dict] = None
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
messages=message_dicts, **params
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
@ -88,4 +95,47 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
return self._create_chat_result(response)
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, LCHumanMessageWithFiles):
content = [
{
"type": "text",
"text": message.content
}
]
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
content.append({
"type": "image_url",
"image_url": {
"url": file.data,
"detail": file.detail.value
}
})
message_dict = {"role": "user", "content": content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict

View File

@ -24,7 +24,7 @@ default_retrieval_model = {
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
'score_threshold_enabled': False
}
@ -82,7 +82,8 @@ class DatasetMultiRetrieverTool(BaseTool):
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
if 'score' in item.metadata and item.metadata['score']:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
@ -192,7 +193,7 @@ class DatasetMultiRetrieverTool(BaseTool):
'search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'top_k': self.top_k,
'score_threshold': self.score_threshold,
@ -210,13 +211,13 @@ class DatasetMultiRetrieverTool(BaseTool):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'search_method': 'hybrid_search',
'embeddings': embeddings,
'score_threshold': retrieval_model[
'score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'score_threshold_enabled'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model[
'reranking_model'] if retrieval_model[

View File

@ -25,7 +25,7 @@ default_retrieval_model = {
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
'score_threshold_enabled': False
}
@ -106,11 +106,11 @@ class DatasetRetrieverTool(BaseTool):
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'score_threshold_enabled'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents,
@ -124,12 +124,12 @@ class DatasetRetrieverTool(BaseTool):
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'score_threshold_enabled'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
@ -148,7 +148,7 @@ class DatasetRetrieverTool(BaseTool):
model_name=retrieval_model['reranking_model']['reranking_model_name']
)
documents = hybrid_rerank.rerank(query, documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
self.top_k)
else:
documents = []
@ -158,7 +158,8 @@ class DatasetRetrieverTool(BaseTool):
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
if 'score' in item.metadata and item.metadata['score']:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,

View File

@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
else:
return None
def get_ids_by_metadata_field(self, key: str, value: str):
result = self.col.query(
expr=f'metadata["{key}"] == "{value}"',
output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def get_ids_by_doc_ids(self, doc_ids: list):
result = self.col.query(
expr=f'metadata["doc_id"] in {doc_ids}',

View File

@ -60,7 +60,7 @@ def _create_weaviate_client(**kwargs: Any) -> Any:
def _default_score_normalizer(val: float) -> float:
return 1 - 1 / (1 + np.exp(val))
return 1 - val
def _json_serializable(value: Any) -> Any:
@ -243,7 +243,8 @@ class Weaviate(VectorStore):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_bm25(query=content).with_limit(k).do()
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
@ -380,14 +381,14 @@ class Weaviate(VectorStore):
result = (
query_obj.with_near_vector(vector)
.with_limit(k)
.with_additional("vector")
.with_additional(["vector", "distance"])
.do()
)
else:
result = (
query_obj.with_near_text(content)
.with_limit(k)
.with_additional("vector")
.with_additional(["vector", "distance"])
.do()
)
@ -397,7 +398,7 @@ class Weaviate(VectorStore):
docs_and_scores = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
score = np.dot(res["_additional"]["vector"], embedded_query)
score = res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
return docs_and_scores

View File

@ -1,4 +1,4 @@
from langchain.vectorstores import Weaviate
from core.vector_store.vector.weaviate import Weaviate
class WeaviateVectorStore(Weaviate):

Some files were not shown because too many files have changed in this diff Show More