Mirror of https://github.com/langgenius/dify.git (synced 2026-02-11 13:55:42 +08:00)

Compare commits (85 commits)
| SHA1 |
|---|
| 5bd3b02be6 |
| 3cf5c1853d |
| a4d86496e1 |
| 90bdc85f8c |
| 0828873b52 |
| 816b707a16 |
| c9257ab4bf |
| 69ce3b3d33 |
| c4caa7c401 |
| dc93a292c3 |
| 174ee1b646 |
| 9b1c4f47fb |
| 582ba45c00 |
| f1cbd55007 |
| 3a34370422 |
| 29ab244de6 |
| 920b2c2b40 |
| ac96d192a6 |
| 07fbeb6cf0 |
| fc64cdee64 |
| 0c0e96c55f |
| 5b953c1ef2 |
| 562ca45e07 |
| 6bbd53512e |
| e352a8ed1b |
| e55225e2bc |
| 3e63abd335 |
| 0620fa3094 |
| d93288f711 |
| ca69af7b97 |
| 952e13fef8 |
| 4be3087642 |
| 49da8a23a8 |
| 3ad943a9eb |
| 3082093293 |
| b03bbab5ad |
| 9574730050 |
| 91ea6fe4ee |
| 769be13189 |
| e42175241e |
| 12257b438b |
| 9ecc736c30 |
| 6c4e6bf1d6 |
| 97fe817186 |
| 52b12ed7eb |
| d8ab4474b4 |
| 1ecbd95adf |
| cad6e6624f |
| 3505cbe05c |
| e15359e589 |
| edb86f5f5a |
| adf2651d1f |
| 5031d64e28 |
| ae3ad59b16 |
| e6cd7b0467 |
| 97e9f52331 |
| 25957d917a |
| 20b932da97 |
| 207080babc |
| 48bacd01cc |
| 297d0f1f30 |
| eedbe1b770 |
| 5ff6b1da07 |
| 8b49e0ee2a |
| e031ec9359 |
| 1bd1cd6938 |
| 81c5a21b8d |
| 61e4bbabaf |
| 4cf475680d |
| ca4aa340f6 |
| 767d8a4b05 |
| 0b8dcaba8f |
| af6a318aae |
| c6e2900be7 |
| 963d9b6032 |
| b2ee738bb1 |
| c8ca3ff404 |
| 5d8fa2c7af |
| 58df5e5376 |
| 348ad1a624 |
| 73e17d5aa8 |
| 300d9892a5 |
| e47b5b43b8 |
| 21c9d9e200 |
| 4f6916c4d8 |
.github/ISSUE_TEMPLATE/bug_report.yml (4 changes, vendored)
```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
 
   - type: input
```
.github/ISSUE_TEMPLATE/document_issue.yml (4 changes, vendored)
```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: textarea
     attributes:
```
.github/ISSUE_TEMPLATE/feature_request.yml (4 changes, vendored)
```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: textarea
     attributes:
```
.github/ISSUE_TEMPLATE/help_wanted.yml (4 changes, vendored)
```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: textarea
     attributes:
```
.github/ISSUE_TEMPLATE/translation_issue.yml (4 changes, vendored)
```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: input
     attributes:
```
.github/pull_request_template.md (new file, 30 lines, vendored)
```diff
@@ -0,0 +1,30 @@
+# Description
+
+Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
+
+Fixes # (issue)
+
+## Type of Change
+
+Please delete options that are not relevant.
+
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
+- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
+
+# How Has This Been Tested?
+
+Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
+
+- [ ] TODO
+
+# Suggested Checklist:
+
+- [ ] I have performed a self-review of my own code
+- [ ] I have commented my code, particularly in hard-to-understand areas
+- [ ] My changes generate no new warnings
+- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
+- [ ] `optional` I have made corresponding changes to the documentation
+- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
+- [ ] `optional` New and existing unit tests pass locally with my changes
```
.github/workflows/tool-test-sdks.yaml (new file, 34 lines, vendored)
```diff
@@ -0,0 +1,34 @@
+name: Run Unit Test For SDKs
+
+on:
+  pull_request:
+    branches:
+      - main
+jobs:
+  build:
+    name: unit test for Node.js SDK
+    runs-on: ubuntu-latest
+
+    strategy:
+      matrix:
+        node-version: [16, 18, 20]
+
+    defaults:
+      run:
+        working-directory: sdks/nodejs-client
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Use Node.js ${{ matrix.node-version }}
+        uses: actions/setup-node@v4
+        with:
+          node-version: ${{ matrix.node-version }}
+          cache: ''
+          cache-dependency-path: 'yarn.lock'
+
+      - name: Install Dependencies
+        run: yarn install
+
+      - name: Test
+        run: yarn test
```
README.md (11 changes)
```diff
@@ -21,6 +21,17 @@
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
 </p>
 
+<p align="center">
+  <a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
+    Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
+  </a>
+  <ul align="center" style="text-decoration: none; list-style: none;">
+    <li> US EST: 09:00 (9:00 AM)</li>
+    <li> CET: 15:00 (3:00 PM)</li>
+    <li> CST: 22:00 (10:00 PM)</li>
+  </ul>
+</p>
+
 <p align="center">
   <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
     Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
```
```diff
@@ -130,3 +130,5 @@ UNSTRUCTURED_API_URL=
 
 SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=
+
+BATCH_UPLOAD_LIMIT=10
```
```diff
@@ -38,10 +38,11 @@ from extensions import (
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService
 
 # DO NOT REMOVE BELOW
+from services.account_service import AccountService
+
+# DO NOT REMOVE BELOW
 from events import event_handlers
 from models import account, dataset, model, source, task, tool, tools, web
 # DO NOT REMOVE ABOVE
```
api/commands.py (158 changes)
```diff
@@ -6,15 +6,15 @@ import click
 from flask import current_app
 from werkzeug.exceptions import NotFound
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import Tenant
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from models.model import Account
 from models.provider import Provider, ProviderModel
```
```diff
@@ -124,14 +124,15 @@ def reset_encrypt_key_pair():
         'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
 
 
-@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
-def create_qdrant_indexes():
+@click.command('vdb-migrate', help='migrate vector db.')
+def vdb_migrate():
     """
-    Migrate other vector database datas to Qdrant.
+    Migrate vector database datas to target vector database .
     """
-    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    click.echo(click.style('Start migrate vector db.', fg='green'))
     create_count = 0
 
+    config = current_app.config
+    vector_type = config.get('VECTOR_STORE')
     page = 1
     while True:
         try:
```
```diff
@@ -140,54 +141,101 @@ def create_qdrant_indexes():
         except NotFound:
             break
 
         model_manager = ModelManager()
 
         page += 1
         for dataset in datasets:
-            if dataset.index_struct_dict:
-                if dataset.index_struct_dict['type'] != 'qdrant':
-                    try:
-                        click.echo('Create dataset qdrant index: {}'.format(dataset.id))
-                        try:
-                            embedding_model = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model
-                            )
-                        except Exception:
-                            continue
-                        embeddings = CacheEmbedding(embedding_model)
-
-                        from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-                        index = QdrantVectorIndex(
-                            dataset=dataset,
-                            config=QdrantConfig(
-                                endpoint=current_app.config.get('QDRANT_URL'),
-                                api_key=current_app.config.get('QDRANT_API_KEY'),
-                                root_path=current_app.root_path
-                            ),
-                            embeddings=embeddings
-                        )
-                        if index:
-                            index.create_qdrant_dataset(dataset)
-                            index_struct = {
-                                "type": 'qdrant',
-                                "vector_store": {
-                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                            }
-                            dataset.index_struct = json.dumps(index_struct)
-                            db.session.commit()
-                            create_count += 1
-                        else:
-                            click.echo('passed.')
-                    except Exception as e:
-                        click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                        fg='red'))
+            try:
+                click.echo('Create dataset vdb index: {}'.format(dataset.id))
+                if dataset.index_struct_dict:
+                    if dataset.index_struct_dict['type'] == vector_type:
+                        continue
+                if vector_type == "weaviate":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'weaviate',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "qdrant":
+                    if dataset.collection_binding_id:
+                        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                            one_or_none()
+                        if dataset_collection_binding:
+                            collection_name = dataset_collection_binding.collection_name
+                        else:
+                            raise ValueError('Dataset Collection Bindings is not exist!')
+                    else:
+                        dataset_id = dataset.id
+                        collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+
+                elif vector_type == "milvus":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'milvus',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                else:
+                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+                vector = Vector(dataset)
+                click.echo(f"vdb_migrate {dataset.id}")
+
+                try:
+                    vector.delete()
+                except Exception as e:
+                    raise e
+
+                dataset_documents = db.session.query(DatasetDocument).filter(
+                    DatasetDocument.dataset_id == dataset.id,
+                    DatasetDocument.indexing_status == 'completed',
+                    DatasetDocument.enabled == True,
+                ).all()
+
+                documents = []
+                for dataset_document in dataset_documents:
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.status == 'completed',
+                        DocumentSegment.enabled == True
+                    ).all()
+
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+
+                        documents.append(document)
+
+                if documents:
+                    try:
+                        vector.create(documents)
+                    except Exception as e:
+                        raise e
+                click.echo(f"Dataset {dataset.id} create successfully.")
+                db.session.add(dataset)
+                db.session.commit()
+                create_count += 1
+            except Exception as e:
+                db.session.rollback()
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue
 
     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
```
```diff
@@ -196,4 +244,4 @@ def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_encrypt_key_pair)
-    app.cli.add_command(create_qdrant_indexes)
+    app.cli.add_command(vdb_migrate)
```
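Since the command is registered on the Flask CLI via `register_commands`, the migration is presumably invoked as `flask vdb-migrate` from the `api` directory, after pointing `VECTOR_STORE` at the target store (`weaviate`, `qdrant`, or `milvus`) so the command knows which `index_struct` to rebuild for each dataset.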
```diff
@@ -38,7 +38,9 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
+    'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
     'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -56,6 +58,8 @@ DEFAULTS = {
     'BILLING_ENABLED': 'False',
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
+    'KEYWORD_STORE': 'jieba',
+    'BATCH_UPLOAD_LIMIT': 20
 }
@@ -86,7 +90,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.5.5"
+        self.CURRENT_VERSION = "0.5.7"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -182,7 +186,7 @@ class Config:
         # Currently, only support: qdrant, milvus, zilliz, weaviate
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
-
+        self.KEYWORD_STORE = get_env('KEYWORD_STORE')
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -259,8 +263,10 @@ class Config:
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
         self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
+        self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
         self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
 
         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -285,6 +291,8 @@ class Config:
         self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
         self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
 
+        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
+
 
 class CloudEditionConfig(Config):
```
```diff
@@ -1,9 +1,8 @@
-
 import json
 
 from models.model import AppModelConfig
 
-languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
+languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']
 
 language_timezone_mapping = {
     'en-US': 'America/New_York',
@@ -16,8 +15,10 @@ language_timezone_mapping = {
     'ko-KR': 'Asia/Seoul',
     'ru-RU': 'Europe/Moscow',
     'it-IT': 'Europe/Rome',
+    'uk-UA': 'Europe/Kyiv',
 }
 
+
 def supported_language(lang):
     if lang in languages:
         return lang
@@ -26,6 +27,7 @@ def supported_language(lang):
                  .format(lang=lang))
     raise ValueError(error)
 
+
 user_input_form_template = {
     "en-US": [
         {
@@ -67,6 +69,16 @@ user_input_form_template = {
             }
         }
     ],
+    "ua-UK": [
+        {
+            "paragraph": {
+                "label": "Запит",
+                "variable": "default_input",
+                "required": False,
+                "default": ""
+            }
+        }
+    ],
 }
 
 demo_model_templates = {
@@ -145,7 +157,7 @@ demo_model_templates = {
                         'Italian',
                     ]
                 }
-            },{
+            }, {
                 "paragraph": {
                     "label": "Query",
                     "variable": "query",
@@ -272,7 +284,7 @@ demo_model_templates = {
                         "意大利语",
                     ]
                 }
-            },{
+            }, {
                 "paragraph": {
                     "label": "文本内容",
                     "variable": "query",
@@ -323,5 +335,130 @@ demo_model_templates = {
         )
         }
     ],
+    'uk-UA': [{
+        "name": "Помічник перекладу",
+        "icon": "",
+        "icon_background": "",
+        "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
+        "mode": "completion",
+        "model_config": AppModelConfig(
+            provider="openai",
+            model_id="gpt-3.5-turbo-instruct",
+            configs={
+                "prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
+                "prompt_variables": [
+                    {
+                        "key": "target_language",
+                        "name": "Цільова мова",
+                        "description": "Мова, на яку ви хочете перекласти.",
+                        "type": "select",
+                        "default": "Ukrainian",
+                        "options": [
+                            "Chinese",
+                            "English",
+                            "Japanese",
+                            "French",
+                            "Russian",
+                            "German",
+                            "Spanish",
+                            "Korean",
+                            "Italian",
+                        ],
+                    },
+                ],
+                "completion_params": {
+                    "max_token": 1000,
+                    "temperature": 0,
+                    "top_p": 0,
+                    "presence_penalty": 0.1,
+                    "frequency_penalty": 0.1,
+                },
+            },
+            opening_statement="",
+            suggested_questions=None,
+            pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
+            model=json.dumps({
+                "provider": "openai",
+                "name": "gpt-3.5-turbo-instruct",
+                "mode": "completion",
+                "completion_params": {
+                    "max_tokens": 1000,
+                    "temperature": 0,
+                    "top_p": 0,
+                    "presence_penalty": 0.1,
+                    "frequency_penalty": 0.1,
+                },
+            }),
+            user_input_form=json.dumps([
+                {
+                    "select": {
+                        "label": "Цільова мова",
+                        "variable": "target_language",
+                        "description": "Мова, на яку ви хочете перекласти.",
+                        "default": "Chinese",
+                        "required": True,
+                        'options': [
+                            'Chinese',
+                            'English',
+                            'Japanese',
+                            'French',
+                            'Russian',
+                            'German',
+                            'Spanish',
+                            'Korean',
+                            'Italian',
+                        ]
+                    }
+                }, {
+                    "paragraph": {
+                        "label": "Запит",
+                        "variable": "query",
+                        "required": True,
+                        "default": ""
+                    }
+                }
+            ])
+        )
+    },
+    {
+        "name": "AI інтерв’юер фронтенду",
+        "icon": "",
+        "icon_background": "",
+        "description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.",
+        "mode": "chat",
+        "model_config": AppModelConfig(
+            provider="openai",
+            model_id="gpt-3.5-turbo",
+            configs={
+                "introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
+                "prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
+                "prompt_variables": [],
+                "completion_params": {
+                    "max_token": 300,
+                    "temperature": 0.8,
+                    "top_p": 0.9,
+                    "presence_penalty": 0.1,
+                    "frequency_penalty": 0.1,
+                },
+            },
+            opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
+            suggested_questions=None,
+            pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
+            model=json.dumps({
+                "provider": "openai",
+                "name": "gpt-3.5-turbo",
+                "mode": "chat",
+                "completion_params": {
+                    "max_tokens": 300,
+                    "temperature": 0.8,
+                    "top_p": 0.9,
+                    "presence_penalty": 0.1,
+                    "frequency_penalty": 0.1,
+                },
+            }),
+            user_input_form=None
+        ),
+    }
+    ],
 
 }
```
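As the hunk above shows, `supported_language` is a validator over the `languages` list: it echoes back a known code and raises `ValueError` otherwise, so adding `'uk-UA'` to the list is what makes the new locale accepted. A minimal usage sketch, assuming the module path `constants.languages` used elsewhere in the repo:

```python
# Hypothetical import path; the file shown in this diff defines these names.
from constants.languages import supported_language

assert supported_language('uk-UA') == 'uk-UA'  # valid after this change

try:
    supported_language('xx-XX')  # not present in `languages`
except ValueError as err:
    print(err)  # message built via .format(lang=lang), as in the function above
```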
```diff
@@ -124,19 +124,13 @@ class AppListApi(Resource):
             available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
             provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
             if provider_model not in available_models_names:
-                model_manager = ModelManager()
-                model_instance = model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id,
-                    model_type=ModelType.LLM
-                )
-
-                if not model_instance:
+                if not default_model_entity:
                     raise ProviderNotInitializeError(
                         "No Default System Reasoning Model available. Please configure "
                         "in the Settings -> Model Provider.")
                 else:
-                    model_config_dict["model"]["provider"] = model_instance.provider
-                    model_config_dict["model"]["name"] = model_instance.model
+                    model_config_dict["model"]["provider"] = default_model_entity.provider
+                    model_config_dict["model"]["name"] = default_model_entity.model
 
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
```
```diff
@@ -1,7 +1,7 @@
 import logging
 
 from flask import request
-from flask_restful import Resource
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError
 
 import services
@@ -45,7 +45,8 @@ class ChatMessageAudioApi(Resource):
         try:
             response = AudioService.transcript_asr(
                 tenant_id=app_model.tenant_id,
-                file=file
+                file=file,
+                end_user=None,
             )
 
             return response
@@ -71,7 +72,7 @@ class ChatMessageAudioApi(Resource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception("internal server error.")
+            logging.exception(f"internal server error, {str(e)}.")
             raise InternalServerError()
 
 
@@ -82,10 +83,12 @@ class ChatMessageTextApi(Resource):
     def post(self, app_id):
         app_id = str(app_id)
         app_model = _get_app(app_id, None)
 
         try:
             response = AudioService.transcript_tts(
                 tenant_id=app_model.tenant_id,
                 text=request.form['text'],
+                voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
                 streaming=False
             )
 
@@ -112,9 +115,50 @@ class ChatMessageTextApi(Resource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception("internal server error.")
+            logging.exception(f"internal server error, {str(e)}.")
             raise InternalServerError()
 
 
+class TextModesApi(Resource):
+    def get(self, app_id: str):
+        app_model = _get_app(str(app_id))
+
+        try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('language', type=str, required=True, location='args')
+            args = parser.parse_args()
+
+            response = AudioService.transcript_tts_voices(
+                tenant_id=app_model.tenant_id,
+                language=args['language'],
+            )
+
+            return response
+        except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
+            raise AppUnavailableError("Text to audio voices language parameter loss.")
+        except NoAudioUploadedServiceError:
+            raise NoAudioUploadedError()
+        except AudioTooLargeServiceError as e:
+            raise AudioTooLargeError(str(e))
+        except UnsupportedAudioTypeServiceError:
+            raise UnsupportedAudioTypeError()
+        except ProviderNotSupportSpeechToTextServiceError:
+            raise ProviderNotSupportSpeechToTextError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logging.exception(f"internal server error, {str(e)}.")
+            raise InternalServerError()
+
+
 api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
 api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
+api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
```
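For reference, the new voices endpoint takes a required `language` query argument and returns whatever `AudioService.transcript_tts_voices` produces. A sketch of the request shape only; the `/console/api` prefix, port, and authentication are assumptions, and `app_id` is a placeholder:

```python
import requests

app_id = "00000000-0000-0000-0000-000000000000"  # placeholder app UUID

# Hypothetical call; console endpoints normally require a logged-in session.
resp = requests.get(
    f"http://localhost:5001/console/api/apps/{app_id}/text-to-audio/voices",
    params={"language": "en-US"},  # required, per the RequestParser above
)
print(resp.status_code, resp.json())
```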
```diff
@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
@@ -173,14 +174,15 @@ class DataSourceNotionApi(Resource):
         if not data_source_binding:
             raise NotFound('Data source binding not found.')
 
-        loader = NotionLoader(
-            notion_access_token=data_source_binding.access_token,
+        extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
-            notion_page_type=page_type
+            notion_page_type=page_type,
+            notion_access_token=data_source_binding.access_token,
+            tenant_id=current_user.current_tenant_id
         )
 
-        text_docs = loader.load()
+        text_docs = extractor.extract()
         return {
             'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200
@@ -192,11 +194,31 @@ class DataSourceNotionApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        notion_info_list = args['notion_info_list']
+        extract_settings = []
+        for notion_info in notion_info_list:
+            workspace_id = notion_info['workspace_id']
+            for page in notion_info['pages']:
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": workspace_id,
+                        "notion_obj_id": page['page_id'],
+                        "notion_page_type": page['type'],
+                        "tenant_id": current_user.current_tenant_id
+                    },
+                    document_model=args['doc_form']
+                )
+                extract_settings.append(extract_setting)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                     args['process_rule'], args['doc_form'],
+                                                     args['doc_language'])
         return response, 200
```
```diff
@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -178,9 +179,9 @@ class DatasetApi(Resource):
                             location='json', store_missing=False,
                             type=_validate_description_length)
         parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            nullable=True,
-                            help='Invalid indexing technique.')
+                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                            nullable=True,
+                            help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('indexing_technique', type=str, required=True,
+        parser.add_argument('indexing_technique', type=str, required=True,
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        extract_settings = []
         if args['info_list']['data_source_type'] == 'upload_file':
             file_ids = args['info_list']['file_info_list']['file_ids']
             file_details = db.session.query(UploadFile).filter(
@@ -278,37 +280,45 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details is None:
                 raise NotFound("File not found.")
 
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'],
-                                                                  args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            if file_details:
+                for file_detail in file_details:
+                    extract_setting = ExtractSetting(
+                        datasource_type="upload_file",
+                        upload_file=file_detail,
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         elif args['info_list']['data_source_type'] == 'notion_import':
-
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    args['info_list']['notion_info_list'],
-                                                                    args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'],
-                                                                    args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            notion_info_list = args['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                for page in notion_info['pages']:
+                    extract_setting = ExtractSetting(
+                        datasource_type="notion_import",
+                        notion_info={
+                            "notion_workspace_id": workspace_id,
+                            "notion_obj_id": page['page_id'],
+                            "notion_page_type": page['type'],
+                            "tenant_id": current_user.current_tenant_id
+                        },
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
+        else:
+            raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         args['process_rule'], args['doc_form'],
+                                                         args['doc_language'], args['dataset_id'],
+                                                         args['indexing_technique'])
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
 
         return response, 200
@@ -508,4 +518,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
-
```
```diff
@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.document_fields import (
@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
         req_data = request.args
 
         document_id = req_data.get('document_id')
-
+
         # get default rules
         mode = DocumentService.DEFAULT_RULES['mode']
         rules = DocumentService.DEFAULT_RULES['rules']
@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
             if not file:
                 raise NotFound('File not found.')
 
+            extract_setting = ExtractSetting(
+                datasource_type="upload_file",
+                upload_file=file,
+                document_model=document.doc_form
+            )
+
             indexing_runner = IndexingRunner()
 
             try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
+                response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
+                                                             data_process_rule_dict, document.doc_form,
+                                                             'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
+        extract_settings = []
         for document in documents:
             if document.indexing_status in ['completed', 'error']:
                 raise DocumentAlreadyFinishedError()
@@ -424,42 +432,49 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             }
             info_list.append(notion_info)
 
-        if dataset.data_source_type == 'upload_file':
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(info_list)
-            ).all()
+            if document.data_source_type == 'upload_file':
+                file_id = data_source_info['upload_file_id']
+                file_detail = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == current_user.current_tenant_id,
+                    UploadFile.id == file_id
+                ).first()
 
-            if file_details is None:
-                raise NotFound("File not found.")
+                if file_detail is None:
+                    raise NotFound("File not found.")
 
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
 
+            elif document.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": data_source_info['notion_workspace_id'],
+                        "notion_obj_id": data_source_info['notion_page_id'],
+                        "notion_page_type": data_source_info['type'],
+                        "tenant_id": current_user.current_tenant_id
+                    },
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
 
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        elif dataset.data_source_type == 'notion_import':
-
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    info_list,
-                                                                    data_process_rule_dict,
-                                                                    None, 'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        else:
-            raise ValueError('Data source type not support')
+            else:
+                raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         data_process_rule_dict, document.doc_form,
+                                                         'English', dataset_id)
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         return response
```
```diff
@@ -85,6 +85,7 @@ class ChatTextApi(InstalledAppResource):
             response = AudioService.transcript_tts(
                 tenant_id=app_model.tenant_id,
                 text=request.form['text'],
+                voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
                 streaming=False
             )
             return {'data': response.data.decode('latin1')}
```
@ -1,27 +0,0 @@
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def create_or_update_end_user_for_user_id(app_model, user_id):
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='service_api',
|
||||
is_anonymous=True,
|
||||
session_id=user_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
```diff
@@ -1,16 +1,16 @@
 import json
 
 from flask import current_app
-from flask_restful import fields, marshal_with
+from flask_restful import fields, marshal_with, Resource
 
 from controllers.service_api import api
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
 from models.model import App, AppModelConfig
 from models.tools import ApiToolProvider
 
 
-class AppParameterApi(AppApiResource):
+class AppParameterApi(Resource):
     """Resource for app variables."""
 
     variable_fields = {
@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
         'system_parameters': fields.Nested(system_parameters_fields)
     }
 
+    @validate_app_token
     @marshal_with(parameters_fields)
-    def get(self, app_model: App, end_user):
+    def get(self, app_model: App):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
 
@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
         }
     }
 
-class AppMetaApi(AppApiResource):
-    def get(self, app_model: App, end_user):
+class AppMetaApi(Resource):
+    @validate_app_token
+    def get(self, app_model: App):
         """Get app meta"""
         app_model_config: AppModelConfig = app_model.app_model_config
```
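The Service API endpoints in this and the following files drop the `AppApiResource` base class in favor of plain `Resource` classes guarded by `validate_app_token`. The decorator's real implementation lives in `controllers/service_api/wraps.py` and is not part of this diff; the sketch below is only an illustration of the pattern, with `lookup_app_by_token` and `lookup_end_user` as hypothetical stand-ins for whatever lookups the actual code performs:

```python
from functools import wraps

from flask import request


def lookup_app_by_token(token):
    """Hypothetical stand-in: the real code resolves an App row from the token."""
    raise NotImplementedError


def lookup_end_user(app_model, fetch_user_arg):
    """Hypothetical stand-in: reads the 'user' field from JSON/query/form
    (per WhereisUserArg) and creates or fetches the matching EndUser."""
    raise NotImplementedError


def validate_app_token(view=None, *, fetch_user_arg=None):
    def decorator(view_func):
        @wraps(view_func)
        def wrapper(*args, **kwargs):
            auth = request.headers.get('Authorization', '')
            token = auth.removeprefix('Bearer ').strip()
            app_model = lookup_app_by_token(token)  # 401 on failure in real code
            if fetch_user_arg is None:
                return view_func(*args, app_model=app_model, **kwargs)
            end_user = lookup_end_user(app_model, fetch_user_arg)
            return view_func(*args, app_model=app_model, end_user=end_user, **kwargs)
        return wrapper
    # supports both bare @validate_app_token and @validate_app_token(fetch_user_arg=...)
    return decorator(view) if callable(view) else decorator
```

This matches the two call shapes visible in the diffs: bare `@validate_app_token` on `AppParameterApi.get` and `AppMetaApi.get`, and `@validate_app_token(fetch_user_arg=...)` on the audio, completion, and conversation endpoints.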
```diff
@@ -1,7 +1,7 @@
 import logging
 
 from flask import request
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError
 
 import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
     ProviderQuotaExceededError,
     UnsupportedAudioTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppModelConfig
+from models.model import App, AppModelConfig, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
 )
 
 
-class AudioApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class AudioApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
+    def post(self, app_model: App, end_user: EndUser):
         app_model_config: AppModelConfig = app_model.app_model_config
 
         if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
             raise InternalServerError()
 
 
-class TextApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class TextApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         parser = reqparse.RequestParser()
         parser.add_argument('text', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
         args = parser.parse_args()
 
@@ -85,7 +86,8 @@ class TextApi(AppApiResource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=args['text'],
-            end_user=args['user'],
+            end_user=end_user,
+            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=args['streaming']
         )
```
```diff
@@ -4,12 +4,11 @@ from collections.abc import Generator
 from typing import Union
 
 from flask import Response, stream_with_context
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError, NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     AppUnavailableError,
     CompletionRequestError,
@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.application_queue_manager import ApplicationQueueManager
 from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.completion_service import CompletionService
 
 
-class CompletionApi(AppApiResource):
-    def post(self, app_model, end_user):
+class CompletionApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
 
@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('user', required=True, nullable=False, type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
 
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         args['auto_generate_name'] = False
 
         try:
@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
             raise InternalServerError()
 
 
-class CompletionStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class CompletionStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
 
-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200
 
 
-class ChatApi(AppApiResource):
-    def post(self, app_model, end_user):
+class ChatApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
 
@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
 
         streaming = args['response_mode'] == 'streaming'
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             response = CompletionService.completion(
                 app_model=app_model,
@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
             raise InternalServerError()
 
 
-class ChatStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class ChatStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200
```
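With user resolution moved into `validate_app_token(fetch_user_arg=...)`, the per-endpoint `user` parser arguments and `create_or_update_end_user_for_user_id` calls become redundant, which is presumably why the helper itself is deleted above; the `user` field is still read from the request, just by the decorator instead of each view.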
```diff
@@ -1,52 +1,44 @@
 from flask import request
-from flask_restful import marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.conversation_service import ConversationService
 
 
-class ConversationApi(AppApiResource):
+class ConversationApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(conversation_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
         parser.add_argument('last_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
 
-class ConversationDetailApi(AppApiResource):
+class ConversationDetailApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def delete(self, app_model, end_user, c_id):
+    def delete(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
         conversation_id = str(c_id)
 
-        user = request.get_json().get('user')
-
-        if end_user is None and user is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user)
-
         try:
             ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
         return {"result": "success"}, 204
 
 
-class ConversationRenameApi(AppApiResource):
+class ConversationRenameApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def post(self, app_model, end_user, c_id):
+    def post(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
 
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('user', type=str, location='json')
         parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.rename(
                 app_model,
```
@ -1,30 +1,27 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.file_fields import file_fields
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(AppApiResource):
|
||||
class FileApi(Resource):
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
|
||||
file = request.files['file']
|
||||
user_args = request.form.get('user')
|
||||
|
||||
if end_user is None and user_args is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
|
||||
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
|
||||
@ -1,20 +1,18 @@
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from extensions.ext_database import db
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import EndUser, Message
|
||||
from models.model import App, EndUser
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(AppApiResource):
|
||||
class MessageListApi(Resource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('user', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(app_model, end_user,
|
||||
args['conversation_id'], args['first_id'], args['limit'])
|
||||
@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class MessageFeedbackApi(AppApiResource):
|
||||
def post(self, app_model, end_user, message_id):
|
||||
class MessageFeedbackApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageSuggestedApi(AppApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
class MessageSuggestedApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
try:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
).first()
|
||||
|
||||
if end_user is None and message.from_end_user_id is not None:
|
||||
user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.id == message.from_end_user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
else:
|
||||
user = end_user
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
user=end_user,
|
||||
message_id=message_id,
|
||||
check_enabled=False
|
||||
)
|
||||
|
||||
@ -1,22 +1,40 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.model import ApiToken, App
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def validate_app_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
class WhereisUserArg(Enum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
JSON = 'json'
|
||||
FORM = 'form'
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
|
||||
fetch_from: WhereisUserArg
|
||||
required: bool = False
|
||||
|
||||
|
||||
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('app')
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
||||
@ -29,16 +47,35 @@ def validate_app_token(view=None):
|
||||
if not app_model.enable_api:
|
||||
raise NotFound()
|
||||
|
||||
return view(app_model, None, *args, **kwargs)
|
||||
return decorated
|
||||
kwargs['app_model'] = app_model
|
||||
|
||||
if view:
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get('user')
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get('user')
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get('user')
|
||||
else:
|
||||
# use default-user
|
||||
user_id = None
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
|
||||
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str,
|
||||
api_token_type: str,
|
||||
@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
|
||||
return api_token
|
||||
|
||||
|
||||
class AppApiResource(Resource):
|
||||
method_decorators = [validate_app_token]
|
||||
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = 'DEFAULT-USER'
|
||||
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='service_api',
|
||||
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
|
||||
session_id=user_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
|
||||
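To make the new decorator contract concrete, here is a minimal sketch of how a Service API resource now receives app_model and end_user as keyword arguments injected by validate_app_token; the PingApi resource itself is hypothetical and only illustrates the wiring shown above:

# Minimal sketch: the decorator validates the bearer token, loads the App,
# resolves (or creates) the EndUser from the configured location, and injects
# both into the view's kwargs. PingApi is illustrative, not from the diff.
from flask_restful import Resource

class PingApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
    def get(self, app_model: App, end_user: EndUser):
        # With required=False, a missing 'user' query parameter falls back to
        # the anonymous 'DEFAULT-USER' session created above.
        return {'app': app_model.name, 'user': end_user.session_id}
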
@@ -68,17 +68,23 @@ class AudioApi(WebApiResource):
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            logging.exception(f"internal server error: {str(e)}")
            raise InternalServerError()


class TextApi(WebApiResource):
    def post(self, app_model: App, end_user):
        app_model_config: AppModelConfig = app_model.app_model_config

        if not app_model_config.text_to_speech_dict['enabled']:
            raise AppUnavailableError()

        try:
            response = AudioService.transcript_tts(
                tenant_id=app_model.tenant_id,
                text=request.form['text'],
                end_user=end_user.external_user_id,
                voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
                streaming=False
            )

@@ -105,7 +111,7 @@ class TextApi(WebApiResource):
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            logging.exception(f"internal server error: {str(e)}")
            raise InternalServerError()



@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
                'id': self._message.id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'answer': event.llm_result.message.content,
                'answer': self._task_state.llm_result.message.content,
                'metadata': {},
                'created_at': int(self._message.created_at.timestamp())
            }

@@ -28,6 +28,7 @@ from core.entities.application_entities import (
    ModelConfigEntity,
    PromptTemplateEntity,
    SensitiveWordAvoidanceEntity,
    TextToSpeechEntity,
)
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@@ -572,7 +573,11 @@ class ApplicationManager:
        text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
        if text_to_speech_dict:
            if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
                properties['text_to_speech'] = True
            properties['text_to_speech'] = TextToSpeechEntity(
                enabled=text_to_speech_dict.get('enabled'),
                voice=text_to_speech_dict.get('voice'),
                language=text_to_speech_dict.get('language'),
            )

        # sensitive word avoidance
        sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')

@@ -1,8 +1,7 @@

from langchain.schema import Document

from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
from models.model import DatasetRetrieverResource

@@ -1,107 +0,0 @@
import tempfile
from pathlib import Path
from typing import Optional, Union

import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile

SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
        response = requests.get(url, headers={
            "User-Agent": USER_AGENT
        })

        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(url).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            with open(file_path, 'wb') as file:
                file.write(response.content)

            return cls.load_from_file(file_path, return_text)

    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[list[Document], str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
        etl_type = current_app.config['ETL_TYPE']
        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
        if etl_type == 'Unstructured':
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
                    else MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx', '.doc']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            elif file_extension == '.msg':
                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
            elif file_extension == '.eml':
                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
            elif file_extension == '.ppt':
                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
            elif file_extension == '.pptx':
                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
            elif file_extension == '.xml':
                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
            else:
                # txt
                loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
                    else TextLoader(file_path, autodetect_encoding=True)
        else:
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx', '.doc']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

@@ -1,55 +0,0 @@
import logging
from typing import Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> list[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass
        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents

@@ -1,12 +1,12 @@
from collections.abc import Sequence
from typing import Any, Optional, cast

from langchain.schema import Document
from sqlalchemy import func

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


@@ -3,12 +3,12 @@ import logging
from typing import Optional, cast

import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper

@@ -42,6 +42,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
    """
    Advanced Completion Prompt Template Entity.
    """

    class RolePrefixEntity(BaseModel):
        """
        Role Prefix Entity.
@@ -57,6 +58,7 @@ class PromptTemplateEntity(BaseModel):
    """
    Prompt Template Entity.
    """

    class PromptType(Enum):
        """
        Prompt Type.
@@ -97,6 +99,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
    """
    Dataset Retrieve Config Entity.
    """

    class RetrieveStrategy(Enum):
        """
        Dataset Retrieve Strategy.
@@ -143,6 +146,15 @@ class SensitiveWordAvoidanceEntity(BaseModel):
    config: dict[str, Any] = {}


class TextToSpeechEntity(BaseModel):
    """
    Text To Speech Entity.
    """
    enabled: bool
    voice: Optional[str] = None
    language: Optional[str] = None


class FileUploadEntity(BaseModel):
    """
    File Upload Entity.
@@ -159,6 +171,7 @@ class AgentToolEntity(BaseModel):
    tool_name: str
    tool_parameters: dict[str, Any] = {}


class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.
@@ -166,6 +179,7 @@ class AgentPromptEntity(BaseModel):
    first_prompt: str
    next_iteration: str


class AgentScratchpadUnit(BaseModel):
    """
    Agent First Prompt Entity.
@@ -182,12 +196,14 @@ class AgentScratchpadUnit(BaseModel):
    thought: Optional[str] = None
    action_str: Optional[str] = None
    observation: Optional[str] = None
    action: Optional[Action] = None
    action: Optional[Action] = None


class AgentEntity(BaseModel):
    """
    Agent Entity.
    """

    class Strategy(Enum):
        """
        Agent Strategy.
@@ -202,6 +218,7 @@ class AgentEntity(BaseModel):
    tools: list[AgentToolEntity] = None
    max_iteration: int = 5


class AppOrchestrationConfigEntity(BaseModel):
    """
    App Orchestration Config Entity.
@@ -219,7 +236,7 @@ class AppOrchestrationConfigEntity(BaseModel):
    show_retrieve_source: bool = False
    more_like_this: bool = False
    speech_to_text: bool = False
    text_to_speech: bool = False
    text_to_speech: dict = {}
    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None

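A short sketch of the conversion the ApplicationManager hunk above performs, with a hand-written config dict standing in for copy_app_model_config_dict; the voice and language values are made-up samples:

# Illustrative only: mirrors the text_to_speech conversion shown above.
text_to_speech_dict = {'enabled': True, 'voice': 'alloy', 'language': 'en-US'}  # sample values

if text_to_speech_dict:
    tts = TextToSpeechEntity(
        enabled=text_to_speech_dict.get('enabled'),
        voice=text_to_speech_dict.get('voice'),
        language=text_to_speech_dict.get('language'),
    )
    # Before this change only a bool was stored; now voice and language
    # survive, so TextApi can pass the voice through to AudioService.transcript_tts.
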
@@ -1,13 +1,8 @@
import logging
from typing import Optional

from flask import current_app

from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
            embedding_provider_name = collection_binding_detail.provider_name
            embedding_model_name = collection_binding_detail.model_name

            model_manager = ModelManager()
            model_instance = model_manager.get_model_instance(
                tenant_id=app_record.tenant_id,
                provider=embedding_provider_name,
                model_type=ModelType.TEXT_EMBEDDING,
                model=embedding_model_name
            )

            # get embedding model
            embeddings = CacheEmbedding(model_instance)

            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_provider_name,
                embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
                collection_binding_id=dataset_collection_binding.id
            )

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings,
                attributes=['doc_id', 'annotation_id', 'app_id']
            )
            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])

            documents = vector_index.search(
            documents = vector.search_by_vector(
                query=query,
                search_type='similarity_score_threshold',
                search_kwargs={
                    'k': 1,
                    'score_threshold': score_threshold,
                    'filter': {
                        'group_id': [dataset.id]
                    }
                top_k=1,
                score_threshold=score_threshold,
                filter={
                    'group_id': [dataset.id]
                }
            )


@@ -1,5 +1,6 @@
import json
import logging
import uuid
from datetime import datetime
from mimetypes import guess_extension
from typing import Optional, Union, cast
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = prompt_messages
        self.history_prompt_messages = self.organize_agent_history(
            prompt_messages=prompt_messages or []
        )
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance
@@ -504,17 +514,6 @@ class BaseAssistantApplicationRunner(AppRunner):
        agent_thought.tool_labels_str = json.dumps(labels)

        db.session.commit()

    def get_history_prompt_messages(self) -> list[PromptMessage]:
        """
        Get history prompt messages
        """
        if self.history_prompt_messages is None:
            self.history_prompt_messages = db.session.query(PromptMessage).filter(
                PromptMessage.message_id == self.message.id,
            ).order_by(PromptMessage.position.asc()).all()

        return self.history_prompt_messages

    def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
        """
@@ -589,4 +588,60 @@ class BaseAssistantApplicationRunner(AppRunner):
        """
        db_variables.updated_at = datetime.utcnow()
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.commit()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
            result.append(prompt_messages[0])

        messages: list[Message] = db.session.query(Message).filter(
            Message.conversation_id == self.message.conversation_id,
        ).order_by(Message.created_at.asc()).all()

        for message in messages:
            result.append(UserPromptMessage(content=message.query))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(';')
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        tool_inputs = json.loads(agent_thought.tool_input)
                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type='function',
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                )
                            ))
                            tool_call_response.append(ToolPromptMessage(
                                content=agent_thought.observation,
                                name=tool,
                                tool_call_id=tool_call_id,
                            ))

                        result.extend([
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response
                        ])
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        return result

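For clarity, a sketch of the message sequence organize_agent_history produces for one stored user turn with a single tool call; all values are made up, but the classes and shapes come from the hunk above:

# Hypothetical reconstruction of one turn, matching the loop above.
import json
import uuid

tool_call_id = str(uuid.uuid4())
history = [
    UserPromptMessage(content="What's the weather in Paris?"),
    AssistantPromptMessage(
        content="I should check the weather tool.",        # agent_thought.thought
        tool_calls=[AssistantPromptMessage.ToolCall(
            id=tool_call_id,
            type='function',
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name='weather',
                arguments=json.dumps({'city': 'Paris'}),   # agent_thought.tool_input
            )
        )],
    ),
    ToolPromptMessage(content='18C, cloudy', name='weather', tool_call_id=tool_call_id),
]
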
@@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -39,6 +40,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
        self._repack_app_orchestration_config(app_orchestration_config)

        agent_scratchpad: list[AgentScratchpadUnit] = []
        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)

        # check model mode
        if self.app_orchestration_config.model_config.mode == "completion":
@@ -131,61 +133,95 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
            # recale llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            llm_result: LLMResult = model_instance.invoke_llm(
            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=[],
                stop=app_orchestration_config.model_config.stop,
                stream=False,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            # check llm result
            if not llm_result:
            if not chunks:
                raise ValueError("failed to invoke llm")

            # get scratchpad
            scratchpad = self._extract_response_scratchpad(llm_result.message.content)
            agent_scratchpad.append(scratchpad)

            # get llm usage
            if llm_result.usage:
                increase_usage(llm_usage, llm_result.usage)

            usage_dict = {}
            react_chunks = self._handle_stream_react(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response='',
                thought='',
                action_str='',
                observation='',
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            for chunk in react_chunks:
                if isinstance(chunk, dict):
                    scratchpad.agent_response += json.dumps(chunk)
                    try:
                        if scratchpad.action:
                            raise Exception("")
                        scratchpad.action_str = json.dumps(chunk)
                        scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=chunk['action'],
                            action_input=chunk['action_input']
                        )
                    except:
                        scratchpad.thought += json.dumps(chunk)
                        yield LLMResultChunk(
                            model=self.model_config.model,
                            prompt_messages=prompt_messages,
                            system_fingerprint='',
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(
                                    content=json.dumps(chunk)
                                ),
                                usage=None
                            )
                        )
                else:
                    scratchpad.agent_response += chunk
                    scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint='',
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=chunk
                            ),
                            usage=None
                        )
                    )

            agent_scratchpad.append(scratchpad)

            # get llm usage
            if 'usage' in usage_dict:
                increase_usage(llm_usage, usage_dict['usage'])
            else:
                usage_dict['usage'] = LLMUsage.empty_usage()

            self.save_agent_thought(agent_thought=agent_thought,
                                    tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                    tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                    thought=scratchpad.thought,
                                    observation='',
                                    answer=llm_result.message.content,
                                    answer=scratchpad.agent_response,
                                    messages_ids=[],
                                    llm_usage=llm_result.usage)
                                    llm_usage=usage_dict['usage'])

            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # publish agent thought if it's not empty and there is a action
            if scratchpad.thought and scratchpad.action:
                # check if final answer
                if not scratchpad.action.action_name.lower() == "final answer":
                    yield LLMResultChunk(
                        model=model_instance.model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=scratchpad.thought
                            ),
                            usage=llm_result.usage,
                        ),
                        system_fingerprint=''
                    )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = scratchpad.agent_response or ''
@@ -260,7 +296,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):

            # save scratchpad
            scratchpad.observation = observation
            scratchpad.agent_response = llm_result.message.content

            # save agent thought
            self.save_agent_thought(
@@ -269,7 +304,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                tool_input=tool_call_args,
                thought=None,
                observation=observation,
                answer=llm_result.message.content,
                answer=scratchpad.agent_response,
                messages_ids=message_file_ids,
            )
            self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
@@ -316,6 +351,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                system_fingerprint=''
            ), PublishFrom.APPLICATION_MANAGER)

    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
        -> Generator[Union[str, dict], None, None]:
        def parse_json(json_str):
            try:
                return json.loads(json_str.strip())
            except:
                return json_str

        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
            if not code_blocks:
                return
            for block in code_blocks:
                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
                yield parse_json(json_text)

        code_block_cache = ''
        code_block_delimiter_count = 0
        in_code_block = False
        json_cache = ''
        json_quote_count = 0
        in_json = False
        got_json = False

        for response in llm_response:
            response = response.delta.message.content
            if not isinstance(response, str):
                continue

            # stream
            index = 0
            while index < len(response):
                steps = 1
                delta = response[index:index+steps]
                if delta == '`':
                    code_block_cache += delta
                    code_block_delimiter_count += 1
                else:
                    if not in_code_block:
                        if code_block_delimiter_count > 0:
                            yield code_block_cache
                        code_block_cache = ''
                    else:
                        code_block_cache += delta
                    code_block_delimiter_count = 0

                if code_block_delimiter_count == 3:
                    if in_code_block:
                        yield from extra_json_from_code_block(code_block_cache)
                        code_block_cache = ''

                    in_code_block = not in_code_block
                    code_block_delimiter_count = 0

                if not in_code_block:
                    # handle single json
                    if delta == '{':
                        json_quote_count += 1
                        in_json = True
                        json_cache += delta
                    elif delta == '}':
                        json_cache += delta
                        if json_quote_count > 0:
                            json_quote_count -= 1
                            if json_quote_count == 0:
                                in_json = False
                                got_json = True
                                index += steps
                                continue
                    else:
                        if in_json:
                            json_cache += delta

                    if got_json:
                        got_json = False
                        yield parse_json(json_cache)
                        json_cache = ''
                        json_quote_count = 0
                        in_json = False

                if not in_code_block and not in_json:
                    yield delta.replace('`', '')

                index += steps

        if code_block_cache:
            yield code_block_cache

        if json_cache:
            yield parse_json(json_cache)

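A quick illustration of the contract of _handle_stream_react above: prose deltas come back out as plain strings (backticks stripped) and any fenced or braced JSON action block comes out as a parsed dict. The SimpleNamespace stand-ins below are hypothetical; they only mimic the delta.message.content shape the parser reads, and runner is assumed to be an existing AssistantCotApplicationRunner instance:

# Toy harness (stand-in classes are assumptions, not the real entities).
from types import SimpleNamespace

def fake_chunks(text, size=3):
    # Re-chunk a finished completion into small streaming deltas.
    for i in range(0, len(text), size):
        yield SimpleNamespace(delta=SimpleNamespace(message=SimpleNamespace(content=text[i:i + size])))

completion = 'I will call a tool.\n```json\n{"action": "weather", "action_input": {"city": "Paris"}}\n```'
usage = {}
for piece in runner._handle_stream_react(fake_chunks(completion), usage):
    print(repr(piece))  # strings for the prose, one dict for the action block
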
    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
        """
        fill in inputs from external data tools
@@ -327,122 +453,40 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                continue

        return instruction

    def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:

    def _init_agent_scratchpad(self,
                               agent_scratchpad: list[AgentScratchpadUnit],
                               messages: list[PromptMessage]
                               ) -> list[AgentScratchpadUnit]:
        """
        extract response from llm response
        init agent scratchpad
        """
        def extra_quotes() -> AgentScratchpadUnit:
            agent_response = content
            # try to extract all quotes
            pattern = re.compile(r'```(.*?)```', re.DOTALL)
            quotes = pattern.findall(content)

            # try to extract action from end to start
            for i in range(len(quotes) - 1, 0, -1):
                """
                1. use json load to parse action
                2. use plain text `Action: xxx` to parse action
                """
                try:
                    action = json.loads(quotes[i].replace('```', ''))
                    action_name = action.get("action")
                    action_input = action.get("action_input")
                    agent_thought = agent_response.replace(quotes[i], '')

                    if action_name and action_input:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=quotes[i],
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name,
                                action_input=action_input,
                            )
        current_scratchpad: AgentScratchpadUnit = None
        for message in messages:
            if isinstance(message, AssistantPromptMessage):
                current_scratchpad = AgentScratchpadUnit(
                    agent_response=message.content,
                    thought=message.content,
                    action_str='',
                    action=None,
                    observation=None,
                )
                if message.tool_calls:
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments)
                        )
                    except:
                    # try to parse action from plain text
                    action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
                    action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
                    # delete action from agent response
                    agent_thought = agent_response.replace(quotes[i], '')
                    # remove extra quotes
                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
                    # remove Action: xxx from agent thought
                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

                    if action_name and action_input:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=quotes[i],
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name[0],
                                action_input=action_input[0],
                            )
                        )

        def extra_json():
            agent_response = content
            # try to extract all json
            structures, pair_match_stack = [], []
            started_at, end_at = 0, 0
            for i in range(len(content)):
                if content[i] == '{':
                    pair_match_stack.append(i)
                    if len(pair_match_stack) == 1:
                        started_at = i
                elif content[i] == '}':
                    begin = pair_match_stack.pop()
                    if not pair_match_stack:
                        end_at = i + 1
                        structures.append((content[begin:i+1], (started_at, end_at)))

            # handle the last character
            if pair_match_stack:
                end_at = len(content)
                structures.append((content[pair_match_stack[0]:], (started_at, end_at)))

            for i in range(len(structures), 0, -1):
                try:
                    json_content, (started_at, end_at) = structures[i - 1]
                    action = json.loads(json_content)
                    action_name = action.get("action")
                    action_input = action.get("action_input")
                    # delete json content from agent response
                    agent_thought = agent_response[:started_at] + agent_response[end_at:]
                    # remove extra quotes like ```(json)*\n\n```
                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
                    # remove Action: xxx from agent thought
                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

                    if action_name and action_input is not None:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=json_content,
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name,
                                action_input=action_input,
                            )
                        )
                except:
                    pass

        agent_scratchpad = extra_quotes()
        if agent_scratchpad:
            return agent_scratchpad
        agent_scratchpad = extra_json()
        if agent_scratchpad:
            return agent_scratchpad

        return AgentScratchpadUnit(
            agent_response=content,
            thought=content,
            action_str='',
            action=None
        )
                    except:
                        pass

                agent_scratchpad.append(current_scratchpad)
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    current_scratchpad.observation = message.content

        return agent_scratchpad

    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                   agent_prompt_message: AgentPromptEntity,
                                   ):
@@ -556,15 +600,22 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
        # organize prompt messages
        if mode == "chat":
            # override system message
            overrided = False
            overridden = False
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
                    overrided = True
                    overridden = True
                    break

            # convert tool prompt messages to user prompt messages
            for idx, prompt_message in enumerate(prompt_messages):
                if isinstance(prompt_message, ToolPromptMessage):
                    prompt_messages[idx] = UserPromptMessage(
                        content=prompt_message.content
                    )

            if not overrided:
            if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

@@ -104,37 +104,17 @@ class HostingConfiguration:

        if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
            hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
            trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
            trial_quota = TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=[
                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
                ]
                restrict_models=trial_models
            )
            quotas.append(trial_quota)

        if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
            paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
            paid_quota = PaidHostingQuota(
                restrict_models=[
                    RestrictModel(model="gpt-4", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
                ]
                restrict_models=paid_models
            )
            quotas.append(paid_quota)

@@ -258,3 +238,11 @@ class HostingConfiguration:
        return HostedModerationConfig(
            enabled=False
        )

    @staticmethod
    def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
        models_str = app_config.get(env_var)
        models_list = models_str.split(",") if models_str else []
        return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
                model_name.strip()]

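To show what the env-driven model lists replace the hard-coded ones with, a small sketch of the parsing behavior; the env value is an example, and the dict stands in for the Flask Config object:

# Illustrative: with HOSTED_OPENAI_TRIAL_MODELS set to a comma-separated list,
# parse_restrict_models_from_env yields one RestrictModel per non-empty name,
# whitespace stripped, each with model_type=ModelType.LLM.
app_config = {"HOSTED_OPENAI_TRIAL_MODELS": "gpt-3.5-turbo, gpt-4-turbo-preview"}  # sample env value
models_str = app_config.get("HOSTED_OPENAI_TRIAL_MODELS")
names = [n.strip() for n in models_str.split(",") if n.strip()]
assert names == ["gpt-3.5-turbo", "gpt-4-turbo-preview"]
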
@ -1,51 +0,0 @@
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class IndexBuilder:
|
||||
@classmethod
|
||||
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
|
||||
if indexing_technique == "high_quality":
|
||||
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
|
||||
return None
|
||||
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
return VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
elif indexing_technique == "economy":
|
||||
return KeywordTableIndex(
|
||||
dataset=dataset,
|
||||
config=KeywordTableConfig(
|
||||
max_keywords_per_chunk=10
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError('Unknown indexing technique')
|
||||
|
||||
@classmethod
|
||||
def get_default_high_quality_index(cls, dataset: Dataset):
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=' ')
|
||||
return VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
@@ -1,305 +0,0 @@
import json
import logging
from abc import abstractmethod
from typing import Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import VectorStore

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    @abstractmethod
    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        raise NotImplementedError

    def search(self, query: str, **kwargs: Any) -> list[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') or 'similarity'
        search_kwargs = kwargs.get('search_kwargs') or {}

        if search_type == 'similarity_score_threshold':
            score_threshold = search_kwargs.get("score_threshold")
            if score_threshold is None or not isinstance(score_threshold, float):
                search_kwargs['score_threshold'] = 0.0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # similarity: k
        # mmr: k, fetch_k, lambda_mult
        # similarity_score_threshold: k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        if self.dataset.collection_binding_id:
            vector_store.delete_by_group_id(group_id)
        else:
            vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        try:
            self.delete()
        except Exception as e:
            raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct[:]
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                self.dataset.index_struct = origin_index_struct
                raise e

        dataset.index_struct = json.dumps(self.to_index_struct())

        db.session.commit()

        self.dataset = dataset
        logging.info(f"Dataset {dataset.id} recreated successfully.")

    def create_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"create_qdrant_dataset {dataset.id}")

        try:
            self.delete()
        except Exception as e:
            raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} recreated successfully.")

    def update_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"update_qdrant_dataset {dataset.id}")

        segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == dataset.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).first()

        if segment:
            try:
                exist = self.text_exists(segment.index_node_id)
                if exist:
                    index_struct = {
                        "type": 'qdrant',
                        "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                    }
                    dataset.index_struct = json.dumps(index_struct)
                    db.session.commit()
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} updated successfully.")

    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"restore dataset in one, dataset {dataset.id}")

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            try:
                self.add_texts(documents)
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} restored successfully.")

    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"delete original collection: {dataset.id}")

        self.delete()

        dataset.collection_binding_id = dataset_collection_binding.id
        db.session.add(dataset)
        db.session.commit()

        logging.info(f"Original collection of dataset {dataset.id} deleted successfully.")
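As a minimal usage sketch of the `search` contract in the deleted base class above (the `index` object is assumed to be a concrete `BaseVectorIndex` subclass, with a dataset and embeddings already wired up):

```python
# A minimal sketch, assuming `index` is a concrete BaseVectorIndex subclass.
# 'similarity_score_threshold' attaches the relevance score to doc.metadata['score'];
# other search types delegate to the LangChain retriever.
docs = index.search(
    'what is a vector index?',
    search_type='similarity_score_threshold',
    search_kwargs={'k': 4, 'score_threshold': 0.5},
)
for doc in docs:
    print(doc.metadata['score'], doc.page_content[:80])
```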
@@ -1,165 +0,0 @@
from typing import Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from models.dataset import Dataset


class MilvusConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    secure: bool = False
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['host']:
            raise ValueError("config MILVUS_HOST is required")
        if not values['port']:
            raise ValueError("config MILVUS_PORT is required")
        if not values['user']:
            raise ValueError("config MILVUS_USER is required")
        if not values['password']:
            raise ValueError("config MILVUS_PASSWORD is required")
        return values

    def to_milvus_params(self):
        return {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'secure': self.secure
        }


class MilvusVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'milvus'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        index_params = {
            'metric_type': 'IP',
            'index_type': "HNSW",
            'params': {"M": 8, "efConstruction": 64}
        }
        self._vector_store = MilvusVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            connection_args=self._client_config.to_milvus_params(),
            index_params=index_params
        )

        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = MilvusVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=collection_name,
            ids=uuids,
            content_payload_key='page_content'
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        return MilvusVectorStore(
            collection_name=self.get_index_name(self.dataset),
            embedding_function=self._embeddings,
            connection_args=self._client_config.to_milvus_params()
        )

    def _get_vector_store_class(self) -> type:
        return MilvusVectorStore

    def delete_by_document_id(self, document_id: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_document_id(document_id)
        if ids:
            vector_store.del_texts({
                'filter': f'id in {ids}'
            })

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_metadata_field(key, value)
        if ids:
            vector_store.del_texts({
                'filter': f'id in {ids}'
            })

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_doc_ids(doc_ids)
        vector_store.del_texts({
            'filter': f'id in {ids}'
        })

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        # drop the whole collection for this dataset (Milvus has no
        # Qdrant-style payload filter)
        vector_store.delete()

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        # milvus/zilliz doesn't support bm25 search
        return []
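For orientation, a minimal sketch of wiring `MilvusConfig` into `MilvusVectorIndex`; the host, port, and credentials below are hypothetical placeholders, and the `dataset`/`embeddings` objects are assumed to exist:

```python
# Minimal sketch, assuming a Dataset row and an Embeddings implementation.
from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex

config = MilvusConfig(
    host='127.0.0.1',    # hypothetical host
    port=19530,          # default Milvus gRPC port
    user='root',         # placeholder credentials
    password='example',
    secure=False,
)
index = MilvusVectorIndex(dataset=dataset, config=config, embeddings=embeddings)
index.create(documents)  # builds an HNSW index with inner-product metric, as above
```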
@@ -1,229 +0,0 @@
import os
from typing import Any, Optional, cast

import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    timeout: float = 20
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {
                'path': path
            }
        else:
            return {
                'url': self.endpoint,
                'api_key': self.api_key,
                'timeout': self.timeout
            }


class QdrantVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'qdrant'

    def get_index_name(self, dataset: Dataset) -> str:
        if dataset.collection_binding_id:
            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
                one_or_none()
            if dataset_collection_binding:
                return dataset_collection_binding.collection_name
            else:
                raise ValueError('Dataset collection binding does not exist!')
        else:
            if self.dataset.index_struct_dict:
                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
                return class_prefix

            dataset_id = dataset.id
            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            ids=uuids,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id',
            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                       max_indexing_threads=0, on_disk=False),
            **self._client_config.to_qdrant_params()
        )

        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=collection_name,
            ids=uuids,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id',
            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                       max_indexing_threads=0, on_disk=False),
            **self._client_config.to_qdrant_params()
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store
        attributes = ['doc_id', 'dataset_id', 'document_id']
        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

        return QdrantVectorStore(
            client=client,
            collection_name=self.get_index_name(self.dataset),
            embeddings=self._embeddings,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id'
        )

    def _get_vector_store_class(self) -> type:
        return QdrantVectorStore

    def delete_by_document_id(self, document_id: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models

        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchValue(value=document_id),
                ),
            ],
        ))

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models

        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key=f"metadata.{key}",
                    match=models.MatchValue(value=value),
                ),
            ],
        ))

    def delete_by_ids(self, ids: list[str]) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        for node_id in ids:
            vector_store.del_texts(models.Filter(
                must=[
                    models.FieldCondition(
                        key="metadata.doc_id",
                        match=models.MatchValue(value=node_id),
                    ),
                ],
            ))

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=group_id),
                ),
            ],
        ))

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
            ],
        ))

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        return vector_store.similarity_search_by_bm25(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                )
            ],
        ), kwargs.get('top_k', 2))
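A hedged illustration of the two shapes `to_qdrant_params` can return, mirroring the branch in the config class above (the endpoint and path values are placeholders): a `path:` endpoint selects embedded local storage, anything else is treated as a server URL.

```python
# Minimal sketch of QdrantConfig.to_qdrant_params behavior.
local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app')
print(local.to_qdrant_params())   # {'path': '/app/storage/qdrant'}

remote = QdrantConfig(endpoint='http://localhost:6333', api_key=None, timeout=20, root_path=None)
print(remote.to_qdrant_params())  # {'url': 'http://localhost:6333', 'api_key': None, 'timeout': 20}
```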
@@ -1,90 +0,0 @@
import json
from typing import Optional

from flask import current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document


class VectorIndex:
    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
                 attributes: Optional[list] = None):
        if attributes is None:
            attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
        self._dataset = dataset
        self._embeddings = embeddings
        self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
        self._attributes = attributes

    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
                           attributes: list) -> BaseVectorIndex:
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

        if not vector_type:
            raise ValueError("Vector store must be specified.")

        if vector_type == "weaviate":
            from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex

            return WeaviateVectorIndex(
                dataset=dataset,
                config=WeaviateConfig(
                    endpoint=config.get('WEAVIATE_ENDPOINT'),
                    api_key=config.get('WEAVIATE_API_KEY'),
                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                ),
                embeddings=embeddings,
                attributes=attributes
            )
        elif vector_type == "qdrant":
            from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex

            return QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path,
                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
                ),
                embeddings=embeddings
            )
        elif vector_type == "milvus":
            from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex

            return MilvusVectorIndex(
                dataset=dataset,
                config=MilvusConfig(
                    host=config.get('MILVUS_HOST'),
                    port=config.get('MILVUS_PORT'),
                    user=config.get('MILVUS_USER'),
                    password=config.get('MILVUS_PASSWORD'),
                    secure=config.get('MILVUS_SECURE'),
                ),
                embeddings=embeddings
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def add_texts(self, texts: list[Document], **kwargs):
        if not self._dataset.index_struct_dict:
            self._vector_index.create(texts, **kwargs)
            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
            db.session.commit()
            return

        self._vector_index.add_texts(texts, **kwargs)

    def __getattr__(self, name):
        if self._vector_index is not None:
            method = getattr(self._vector_index, name)
            if callable(method):
                return method

        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
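A minimal sketch of how this facade might have been constructed from the Flask config (the config keys follow the branches above; the `dataset`/`embeddings` objects are assumed to exist):

```python
# Hedged sketch: picks Weaviate/Qdrant/Milvus from VECTOR_STORE, or from the
# dataset's existing index_struct if the dataset was already indexed.
from flask import current_app

index = VectorIndex(
    dataset=dataset,
    config=current_app.config,
    embeddings=embeddings,
)
index.add_texts(documents)  # first call creates the index and persists index_struct
# Other calls are proxied to the underlying BaseVectorIndex via __getattr__.
docs = index.search('query text', search_type='similarity', search_kwargs={'k': 4})
```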
@@ -1,179 +0,0 @@
from typing import Any, Optional, cast

import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class WeaviateConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['endpoint']:
            raise ValueError("config WEAVIATE_ENDPOINT is required")
        return values


class WeaviateVectorIndex(BaseVectorIndex):

    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)
        self._attributes = attributes

    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

        weaviate.connect.connection.has_grpc = False

        try:
            client = weaviate.Client(
                url=config.endpoint,
                auth_client_secret=auth_config,
                timeout_config=(5, 60),
                startup_period=None
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError("Vector database connection error")

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        return 'weaviate'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=collection_name,
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = self._attributes
        if self._is_origin():
            attributes = ['doc_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        return WeaviateVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": [key],
            "valueText": value
        })

    def delete_by_group_id(self, group_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
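For orientation, a minimal sketch of standing up the Weaviate-backed index above (the endpoint is a placeholder; `dataset`, `embeddings`, and `documents` are assumed to exist):

```python
# Minimal sketch, assuming a Dataset row and an Embeddings implementation.
from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex

config = WeaviateConfig(
    endpoint='http://localhost:8080',  # hypothetical endpoint
    api_key=None,
    batch_size=100,                    # enables auto-batching, as configured above
)
index = WeaviateVectorIndex(
    dataset=dataset,
    config=config,
    embeddings=embeddings,
    attributes=['doc_id', 'dataset_id', 'document_id', 'doc_hash'],
)
index.create(documents)
```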
@@ -9,21 +9,21 @@ from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from core.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@@ -31,7 +31,7 @@ from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from models.source import DataSourceBinding
from services.feature_service import FeatureService


class IndexingRunner:
@@ -56,38 +56,19 @@ class IndexingRunner:
        processing_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
            first()
        index_type = dataset_document.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        # extract
        text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

        # load file
        text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
        # transform
        documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
        # save segment
        self._load_segments(dataset, dataset_document, documents)

        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        # get splitter
        splitter = self._get_splitter(processing_rule, embedding_model_instance)

        # split to documents
        documents = self._step_split(
            text_docs=text_docs,
            splitter=splitter,
            dataset=dataset,
            dataset_document=dataset_document,
            processing_rule=processing_rule
        )
        self._build_index(
        # load
        self._load(
            index_processor=index_processor,
            dataset=dataset,
            dataset_document=dataset_document,
            documents=documents
@@ -133,39 +114,19 @@ class IndexingRunner:
            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
            first()

        # load file
        text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
        index_type = dataset_document.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        # extract
        text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )
        # transform
        documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
        # save segment
        self._load_segments(dataset, dataset_document, documents)

        # get splitter
        splitter = self._get_splitter(processing_rule, embedding_model_instance)

        # split to documents
        documents = self._step_split(
            text_docs=text_docs,
            splitter=splitter,
            dataset=dataset,
            dataset_document=dataset_document,
            processing_rule=processing_rule
        )

        # build index
        self._build_index(
        # load
        self._load(
            index_processor=index_processor,
            dataset=dataset,
            dataset_document=dataset_document,
            documents=documents
@@ -219,7 +180,15 @@ class IndexingRunner:
            documents.append(document)

        # build index
        self._build_index(
        # get the process rule
        processing_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
            first()

        index_type = dataset_document.doc_form
        index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
        self._load(
            index_processor=index_processor,
            dataset=dataset,
            dataset_document=dataset_document,
            documents=documents
@@ -238,12 +207,20 @@ class IndexingRunner:
        dataset_document.stopped_at = datetime.datetime.utcnow()
        db.session.commit()

    def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                          doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                          indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
@@ -275,16 +252,18 @@ class IndexingRunner:
        total_segments = 0
        total_price = 0
        currency = 'USD'
        for file_detail in file_details:

        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        all_text_docs = []
        for extract_setting in extract_settings:
            # extract
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            all_text_docs.extend(text_docs)
        processing_rule = DatasetProcessRule(
            mode=tmp_processing_rule["mode"],
            rules=json.dumps(tmp_processing_rule["rules"])
        )

        # load data from file
        text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')

        # get splitter
        splitter = self._get_splitter(processing_rule, embedding_model_instance)

@@ -296,7 +275,6 @@ class IndexingRunner:
        )

        total_segments += len(documents)

        for document in documents:
            if len(preview_texts) < 5:
                preview_texts.append(document.page_content)
@@ -355,146 +333,8 @@ class IndexingRunner:
            "preview": preview_texts
        }

    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                                 indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING
                )
        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule, embedding_model_instance)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )
                total_segments += len(documents)

                embedding_model_type_instance = None
                if embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)
                    if indexing_technique == 'high_quality' and embedding_model_type_instance:
                        tokens += embedding_model_type_instance.get_num_tokens(
                            model=embedding_model_instance.model,
                            credentials=embedding_model_instance.credentials,
                            texts=[document.page_content]
                        )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )

            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency
        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }

    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
    def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
            -> list[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []
@@ -510,11 +350,28 @@ class IndexingRunner:
                one_or_none()

            if file_detail:
                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
                extract_setting = ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

            if (not data_source_info or 'notion_workspace_id' not in data_source_info
                    or 'notion_page_id' not in data_source_info):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info['notion_workspace_id'],
                    "notion_obj_id": data_source_info['notion_page_id'],
                    "notion_page_type": data_source_info['type'],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id
                },
                document_model=dataset_document.doc_form
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
@@ -528,8 +385,6 @@ class IndexingRunner:
        # replace doc id to document model id
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbol
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

@@ -770,12 +625,12 @@ class IndexingRunner:
                for q, a in matches if q and a
            ]

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
    def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
              dataset_document: DatasetDocument, documents: list[Document]) -> None:
        """
        Build the index for the document.
        insert index and update document/segment status to completed
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
@@ -808,13 +663,9 @@ class IndexingRunner:
                )
                for document in chunk_documents
            )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)
            # load index
            index_processor.load(dataset, chunk_documents)
            db.session.add(dataset)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
@@ -894,14 +745,64 @@ class IndexingRunner:
            )
            documents.append(document)
        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts(documents)
    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
                   text_docs: list[Document], process_rule: dict) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
                                              process_rule=process_rule)

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )
        pass


class DocumentIsPausedException(Exception):
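Taken together, the IndexingRunner hunks above replace the old load/split/build steps with a four-stage pipeline. A condensed sketch of the new flow, using only the method names introduced in this diff (surrounding objects assumed):

```python
# Condensed sketch of IndexingRunner.run after this refactor.
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
# extract: pull raw text docs from the upload or Notion import
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# transform: split/clean into indexable documents (embedding model resolved inside)
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
# save segments: persist DocumentSegment rows and flip statuses to "indexing"
self._load_segments(dataset, dataset_document, documents)
# load: write vector/keyword indexes and mark everything completed
self._load(index_processor=index_processor, dataset=dataset,
           dataset_document=dataset_document, documents=documents)
```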
@@ -99,7 +99,8 @@ class ModelInstance:
            user=user
        )

    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
                      top_n: Optional[int] = None,
                      user: Optional[str] = None) \
            -> RerankResult:
        """
@@ -166,13 +167,15 @@ class ModelInstance:
            user=user
        )

    def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
            -> str:
        """
        Invoke large language model
        Invoke large language tts model

        :param content_text: text content to be converted
        :param tenant_id: user tenant id
        :param user: unique user id
        :param voice: model timbre
        :param streaming: output is streaming
        :return: text for given audio file
        """
@@ -185,9 +188,28 @@ class ModelInstance:
            credentials=self.credentials,
            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
            voice=voice,
            streaming=streaming
        )

    def get_tts_voices(self, language: str) -> list:
        """
        Get large language tts model voices

        :param language: tts language
        :return: tts model voices
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        self.model_type_instance = cast(TTSModel, self.model_type_instance)
        return self.model_type_instance.get_tts_model_voices(
            model=self.model,
            credentials=self.credentials,
            language=language
        )


class ModelManager:
    def __init__(self) -> None:

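A hedged sketch of calling the updated TTS surface above; the tenant id and fallback voice are placeholders, and the shape of the voices list is assumed from the schema docs later in this diff:

```python
# Minimal sketch, assuming `model_instance` wraps a TTS model.
voices = model_instance.get_tts_voices(language='en-US')  # e.g. [{'mode': 'alloy', 'name': 'Alloy', ...}]
audio = model_instance.invoke_tts(
    content_text='Hello, Dify!',
    tenant_id='tenant-id-placeholder',  # hypothetical tenant id
    voice=voices[0]['mode'] if voices else 'alloy',
    streaming=False,
)
```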
@@ -20,7 +20,7 @@


Lists all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For the rule design, see: [Schema](./schema.md).
Lists all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For the rule design, see: [Schema](./docs/zh_Hans/schema.md).

- Display of the selectable model list

@@ -86,4 +86,4 @@ Model Runtime is divided into three layers:


### [Concrete interface implementations 👈🏻](./docs/zh_Hans/interfaces.md)
Here you can find the concrete implementation of the interface you want to inspect, along with the exact meaning of its parameters and return values.
Here you can find the concrete implementation of the interface you want to inspect, along with the exact meaning of its parameters and return values.
@@ -48,6 +48,10 @@
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
- `supported_file_extensions` (string) Supported file extension formats, e.g.: mp3, mp4 (available for model type `speech2text`)
- `default_voice` (string) Default voice, e.g.: alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
- `voices` (list) List of available voices (available for model type `tts`)
  - `mode` (string) Voice model (available for model type `tts`)
  - `name` (string) Voice model display name (available for model type `tts`)
  - `lanuage` (string) Languages supported by the voice model (available for model type `tts`)
- `word_limit` (int) Word limit per conversion, split by paragraph by default (available for model type `tts`)
- `audio_type` (string) Supported audio file extension formats, e.g.: mp3, wav (available for model type `tts`)
- `max_workers` (int) Number of concurrent workers for text-to-audio conversion (available for model type `tts`)

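As a hedged illustration of how these `tts` fields fit together (all values invented; in a provider's model YAML they would sit under `model_properties`):

```python
# Hypothetical model_properties for a tts model, mirroring the fields above.
model_properties = {
    'default_voice': 'alloy',
    'voices': [
        # 'lanuage' spelling follows the schema key documented above
        {'mode': 'alloy', 'name': 'Alloy', 'lanuage': 'en-US'},
    ],
    'word_limit': 120,
    'audio_type': 'mp3',
    'max_workers': 5,
}
```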
@@ -48,7 +48,11 @@
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding` `moderation`)
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
- `supported_file_extensions` (string) Supported file extension formats, e.g.: mp3, mp4 (available for model type `speech2text`)
- `default_voice` (string) Default voice, optional: alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
- `default_voice` (string) Default voice, required: alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
- `voices` (list) List of available voices.
  - `mode` (string) Voice model (available for model type `tts`)
  - `name` (string) Voice model display name (available for model type `tts`)
  - `lanuage` (string) Languages supported by the voice model (available for model type `tts`)
- `word_limit` (int) Word limit per conversion, split by paragraph by default (available for model type `tts`)
- `audio_type` (string) Supported audio file extension formats, e.g.: mp3, wav (available for model type `tts`)
- `max_workers` (int) Number of concurrent workers for text-to-audio conversion (available for model type `tts`)

@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
        'min': 1,
        'max': 2048,
        'precision': 0,
    },
    DefaultParameterName.RESPONSE_FORMAT: {
        'label': {
            'en_US': 'Response Format',
            'zh_Hans': '回复格式',
        },
        'type': 'string',
        'help': {
            'en_US': 'Set a response format to ensure the output from the LLM is a valid code block where possible, such as JSON or XML.',
            'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
        },
        'required': False,
        'options': ['JSON', 'XML'],
    }
}
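A minimal sketch of reading the new template entry; the module paths are assumed from this repo's layout and not verified against this exact revision:

```python
# Hedged sketch: import paths are assumptions.
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import DefaultParameterName

rule = PARAMETER_RULE_TEMPLATE[DefaultParameterName.RESPONSE_FORMAT]
print(rule['options'])   # ['JSON', 'XML']
print(rule['required'])  # False
```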
@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
    PRESENCE_PENALTY = "presence_penalty"
    FREQUENCY_PENALTY = "frequency_penalty"
    MAX_TOKENS = "max_tokens"
    RESPONSE_FORMAT = "response_format"

    @classmethod
    def value_of(cls, value: Any) -> 'DefaultParameterName':
@@ -127,6 +128,7 @@ class ModelPropertyKey(Enum):
    SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
    MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
    DEFAULT_VOICE = "default_voice"
    VOICES = "voices"
    WORD_LIMIT = "word_limit"
    AUDOI_TYPE = "audio_type"
    MAX_WORKERS = "max_workers"

@@ -262,23 +262,23 @@ class AIModel(ABC):
            try:
                default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
                default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
                if not parameter_rule.max:
                if not parameter_rule.max and 'max' in default_parameter_rule:
                    parameter_rule.max = default_parameter_rule['max']
                if not parameter_rule.min:
                if not parameter_rule.min and 'min' in default_parameter_rule:
                    parameter_rule.min = default_parameter_rule['min']
                if not parameter_rule.precision:
                if not parameter_rule.default and 'default' in default_parameter_rule:
                    parameter_rule.default = default_parameter_rule['default']
                if not parameter_rule.precision:
                if not parameter_rule.precision and 'precision' in default_parameter_rule:
                    parameter_rule.precision = default_parameter_rule['precision']
                if not parameter_rule.required:
                if not parameter_rule.required and 'required' in default_parameter_rule:
                    parameter_rule.required = default_parameter_rule['required']
                if not parameter_rule.help:
                if not parameter_rule.help and 'help' in default_parameter_rule:
                    parameter_rule.help = I18nObject(
                        en_US=default_parameter_rule['help']['en_US'],
                    )
                if not parameter_rule.help.en_US:
                if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
                    parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
                if not parameter_rule.help.zh_Hans:
                if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
                    parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
            except ValueError:
                pass
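The repeated `'key' in default_parameter_rule` guards above are the substance of this hunk: the old code assumed every default rule defined every field and raised `KeyError` on sparse rules such as `response_format`. A standalone sketch of the same guarded-fallback pattern, with hypothetical names rather than the Dify classes:

```python
# Hypothetical, simplified version of the guarded-fallback merge shown above.
class Rule:
    def __init__(self, max=None, min=None, precision=None):
        self.max, self.min, self.precision = max, min, precision


def merge_defaults(rule: Rule, defaults: dict) -> Rule:
    # Fall back only when the field is unset AND the default rule defines it,
    # so sparse default rules (e.g. without 'precision') no longer raise KeyError.
    for field in ("max", "min", "precision"):
        if getattr(rule, field) is None and field in defaults:
            setattr(rule, field, defaults[field])
    return rule


rule = merge_defaults(Rule(max=100), {"min": 1, "precision": 0})
assert (rule.max, rule.min, rule.precision) == (100, 1, 0)
```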
@@ -9,7 +9,13 @@ from typing import Optional, Union

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
    ModelPropertyKey,
    ModelType,
@@ -74,7 +80,20 @@ class LargeLanguageModel(AIModel):
        )

        try:
            result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
            if "response_format" in model_parameters:
                result = self._code_block_mode_wrapper(
                    model=model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    callbacks=callbacks
                )
            else:
                result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        except Exception as e:
            self._trigger_invoke_error_callbacks(
                model=model,
@@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel):

        return result

    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper: ensure the response is a code block wrapped in markdown quotes

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        """

        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""

        code_block = model_parameters.get("response_format", "")
        if not code_block:
            return self._invoke(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user
            )

        model_parameters.pop("response_format")
        stop = stop or []
        stop.extend(["\n```", "```\n"])
        block_prompts = block_prompts.replace("{{block}}", code_block)

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", prompt_messages[0].content)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
            ))

        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
            # add ```JSON\n to the last message
            prompt_messages[-1].content += f"\n```{code_block}\n"
        else:
            # append a user message
            prompt_messages.append(UserPromptMessage(
                content=f"```{code_block}\n"
            ))

        response = self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

        if isinstance(response, Generator):
            first_chunk = next(response)

            def new_generator():
                yield first_chunk
                yield from response

            if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
                return self._code_block_mode_stream_processor_with_backtick(
                    model=model,
                    prompt_messages=prompt_messages,
                    input_generator=new_generator()
                )
            else:
                return self._code_block_mode_stream_processor(
                    model=model,
                    prompt_messages=prompt_messages,
                    input_generator=new_generator()
                )

        return response
    def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
                                          input_generator: Generator[LLMResultChunk, None, None]
                                          ) -> Generator[LLMResultChunk, None, None]:
        """
        Code block mode stream processor: ensure the response is a code block wrapped in markdown quotes

        :param model: model name
        :param prompt_messages: prompt messages
        :param input_generator: input generator
        :return: output generator
        """
        state = "normal"
        backtick_count = 0
        for piece in input_generator:
            if piece.delta.message.content:
                content = piece.delta.message.content
                piece.delta.message.content = ""
                yield piece
                piece = content
            else:
                yield piece
                continue
            new_piece = ""
            for char in piece:
                if state == "normal":
                    if char == "`":
                        state = "in_backticks"
                        backtick_count = 1
                    else:
                        new_piece += char
                elif state == "in_backticks":
                    if char == "`":
                        backtick_count += 1
                        if backtick_count == 3:
                            state = "skip_content"
                            backtick_count = 0
                    else:
                        new_piece += "`" * backtick_count + char
                        state = "normal"
                        backtick_count = 0
                elif state == "skip_content":
                    if char.isspace():
                        state = "normal"

            if new_piece:
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=new_piece,
                            tool_calls=[]
                        ),
                    )
                )
||||
    def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
                                                        input_generator: Generator[LLMResultChunk, None, None]) \
            -> Generator[LLMResultChunk, None, None]:
        """
        Code block mode stream processor: ensure the response is a code block wrapped in markdown quotes.
        This version skips the language identifier that follows the opening triple backticks.

        :param model: model name
        :param prompt_messages: prompt messages
        :param input_generator: input generator
        :return: output generator
        """
        state = "search_start"
        backtick_count = 0

        for piece in input_generator:
            if piece.delta.message.content:
                content = piece.delta.message.content
                # Reset content to ensure we're only processing and yielding the relevant parts
                piece.delta.message.content = ""
                # Yield a piece with cleared content before processing it to maintain the generator structure
                yield piece
                piece = content
            else:
                # Yield pieces without content directly
                yield piece
                continue

            if state == "done":
                continue

            new_piece = ""
            for char in piece:
                if state == "search_start":
                    if char == "`":
                        backtick_count += 1
                        if backtick_count == 3:
                            state = "skip_language"
                            backtick_count = 0
                    else:
                        backtick_count = 0
                elif state == "skip_language":
                    # Skip everything until the first newline, marking the end of the language identifier
                    if char == "\n":
                        state = "in_code_block"
                elif state == "in_code_block":
                    if char == "`":
                        backtick_count += 1
                        if backtick_count == 3:
                            state = "done"
                            break
                    else:
                        if backtick_count > 0:
                            # If backticks were counted but we're still collecting content, it was a false start
                            new_piece += "`" * backtick_count
                            backtick_count = 0
                        new_piece += char

                elif state == "done":
                    break

            if new_piece:
                # Only yield content collected within the code block
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=new_piece,
                            tool_calls=[]
                        ),
                    )
                )
    def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
                                 prompt_messages: list[PromptMessage], model_parameters: dict,
                                 tools: Optional[list[PromptMessageTool]] = None,
@@ -204,7 +456,7 @@ class LargeLanguageModel(AIModel):
        :return: full response or stream response chunk generator result
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -15,29 +15,37 @@ class TTSModel(AIModel):
    """
    model_type: ModelType = ModelType.TTS

    def invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
    def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
               user: Optional[str] = None):
        """
        Invoke text-to-speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
        :param streaming: output is streaming
        :param user: unique user id
        :return: translated audio file
        """
        try:
            return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, content_text=content_text)
            self._is_ffmpeg_installed()
            return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
                                content_text=content_text, voice=voice, tenant_id=tenant_id)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
                user: Optional[str] = None):
        """
        Invoke text-to-speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
        :param streaming: output is streaming
        :param user: unique user id
@@ -45,7 +53,25 @@ class TTSModel(AIModel):
        """
        raise NotImplementedError

    def _get_model_voice(self, model: str, credentials: dict) -> any:
    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Get the voices available for a given tts model

        :param language: tts language
        :param model: model name
        :param credentials: model credentials
        :return: voices list
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
            voices = model_schema.model_properties[ModelPropertyKey.VOICES]
            if language:
                return [{'name': d['name'], 'value': d['mode']} for d in voices if language and language in d.get('language')]
            else:
                return [{'name': d['name'], 'value': d['mode']} for d in voices]
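The language filter above reduces to a list comprehension over the `voices` entries declared in the model YAML. A small standalone illustration, with data shaped like the `tts-1` definition later in this diff:

```python
# Voice entries shaped like the `voices` block of the tts-1 model YAML.
voices = [
    {"mode": "alloy", "name": "Alloy", "language": ["zh-Hans", "en-US"]},
    {"mode": "onyx", "name": "Onyx", "language": ["en-US"]},
]


def tts_voices(voices: list, language: str | None = None) -> list:
    # With a language given, keep only the voices that support it;
    # otherwise return every declared voice.
    if language:
        return [{"name": v["name"], "value": v["mode"]}
                for v in voices if language in v.get("language", [])]
    return [{"name": v["name"], "value": v["mode"]} for v in voices]


assert tts_voices(voices, "zh-Hans") == [{"name": "Alloy", "value": "alloy"}]
```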
    def _get_model_default_voice(self, model: str, credentials: dict) -> any:
        """
        Get the default voice for a given tts model

@@ -6,6 +6,7 @@
- bedrock
- togetherai
- ollama
- mistralai
- replicate
- huggingface_hub
- zhipuai
@@ -27,6 +27,8 @@ parameter_rules:
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '8.00'
  output: '24.00'
@@ -27,6 +27,8 @@ parameter_rules:
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '8.00'
  output: '24.00'
@@ -26,6 +26,8 @@ parameter_rules:
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '1.63'
  output: '5.51'
@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from httpx import Timeout

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""

class AnthropicLargeLanguageModel(LargeLanguageModel):

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        if 'response_format' in model_parameters and model_parameters['response_format']:
            stop = stop or []
            self._transform_json_prompts(
                model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
            )
            model_parameters.pop('response_format')

        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _transform_json_prompts(self, model: str, credentials: dict,
                                prompt_messages: list[PromptMessage], model_parameters: dict,
                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
            -> None:
        """
        Transform json prompts
        """
        if "```\n" not in stop:
            stop.append("```\n")

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=ANTHROPIC_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", prompt_messages[0].content)
                .replace("{{block}}", response_format)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=ANTHROPIC_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                .replace("{{block}}", response_format)
            ))

        prompt_messages.append(AssistantPromptMessage(
            content=f"```{response_format}\n"
        ))

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -128,8 +128,10 @@ class BaichuanModel:
                'role': message.role,
            })
        # [baichuan] frequency_penalty must be between 1 and 2
        if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
            parameters['frequency_penalty'] = 1
        if 'frequency_penalty' in parameters:
            if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
                parameters['frequency_penalty'] = 1

        # turbo api accepts flat parameters
        return {
            'model': self._model_mapping(model),
@@ -103,7 +103,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
            ], parameters={
                'max_tokens': 1,
            }, timeout=60)
        except (InvalidAPIKeyError, InvalidAuthenticationError) as e:
        except Exception as e:
            raise CredentialsValidateFailedError(f"Invalid API key: {e}")

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
@@ -27,6 +27,8 @@ parameter_rules:
    default: 2048
    min: 1
    max: 2048
  - name: response_format
    use_template: response_format
pricing:
  input: '0.00'
  output: '0.00'
@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""


class GoogleLargeLanguageModel(LargeLanguageModel):

    def _invoke(self, model: str, credentials: dict,
@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)


    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
Binary file not shown. (image added: 6.9 KiB)
Binary file not shown. (image added: 2.3 KiB)
@@ -0,0 +1,5 @@
- open-mistral-7b
- open-mixtral-8x7b
- mistral-small-latest
- mistral-medium-latest
- mistral-large-latest
api/core/model_runtime/model_providers/mistralai/llm/llm.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:

        self._add_custom_parameters(credentials)

        # mistral does not support user/stop arguments
        stop = []
        user = None

        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
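The subclass above is configuration more than code: it pins chat mode and the Mistral endpoint into the credentials, then defers the actual request to the OpenAI-compatible base class. A hedged sketch of how a caller might exercise it; the API key is a placeholder and the keyword layout simply follows the `_invoke` signature shown above:

```python
# Sketch only: the api_key value is a placeholder, not a working credential.
from core.model_runtime.entities.message_entities import UserPromptMessage

llm = MistralAILargeLanguageModel()
llm.validate_credentials(model="open-mistral-7b",
                         credentials={"api_key": "sk-placeholder"})
result = llm.invoke(  # public wrapper around _invoke() in the base class
    model="open-mistral-7b",
    # 'mode' and 'endpoint_url' are injected by _add_custom_parameters
    credentials={"api_key": "sk-placeholder"},
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7},
)
```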
@@ -0,0 +1,50 @@
model: mistral-large-latest
label:
  zh_Hans: mistral-large-latest
  en_US: mistral-large-latest
model_type: llm
features:
  - agent-thought
model_properties:
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.7
    min: 0
    max: 1
  - name: top_p
    use_template: top_p
    default: 1
    min: 0
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
      zh_Hans: 是否开启提示词审查
    label:
      en_US: SafePrompt
      zh_Hans: 提示词审查
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: RandomSeed
      zh_Hans: 随机数种子
    default: 0
    min: 0
    max: 2147483647
pricing:
  input: '0.008'
  output: '0.024'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,50 @@
model: mistral-medium-latest
label:
  zh_Hans: mistral-medium-latest
  en_US: mistral-medium-latest
model_type: llm
features:
  - agent-thought
model_properties:
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.7
    min: 0
    max: 1
  - name: top_p
    use_template: top_p
    default: 1
    min: 0
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
      zh_Hans: 是否开启提示词审查
    label:
      en_US: SafePrompt
      zh_Hans: 提示词审查
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: RandomSeed
      zh_Hans: 随机数种子
    default: 0
    min: 0
    max: 2147483647
pricing:
  input: '0.0027'
  output: '0.0081'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,50 @@
model: mistral-small-latest
label:
  zh_Hans: mistral-small-latest
  en_US: mistral-small-latest
model_type: llm
features:
  - agent-thought
model_properties:
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.7
    min: 0
    max: 1
  - name: top_p
    use_template: top_p
    default: 1
    min: 0
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
      zh_Hans: 是否开启提示词审查
    label:
      en_US: SafePrompt
      zh_Hans: 提示词审查
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: RandomSeed
      zh_Hans: 随机数种子
    default: 0
    min: 0
    max: 2147483647
pricing:
  input: '0.002'
  output: '0.006'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,50 @@
model: open-mistral-7b
label:
  zh_Hans: open-mistral-7b
  en_US: open-mistral-7b
model_type: llm
features:
  - agent-thought
model_properties:
  context_size: 8000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.7
    min: 0
    max: 1
  - name: top_p
    use_template: top_p
    default: 1
    min: 0
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 2048
  - name: safe_prompt
    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
      zh_Hans: 是否开启提示词审查
    label:
      en_US: SafePrompt
      zh_Hans: 提示词审查
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: RandomSeed
      zh_Hans: 随机数种子
    default: 0
    min: 0
    max: 2147483647
pricing:
  input: '0.00025'
  output: '0.00025'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,50 @@
model: open-mixtral-8x7b
label:
  zh_Hans: open-mixtral-8x7b
  en_US: open-mixtral-8x7b
model_type: llm
features:
  - agent-thought
model_properties:
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.7
    min: 0
    max: 1
  - name: top_p
    use_template: top_p
    default: 1
    min: 0
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
      zh_Hans: 是否开启提示词审查
    label:
      en_US: SafePrompt
      zh_Hans: 提示词审查
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: RandomSeed
      zh_Hans: 随机数种子
    default: 0
    min: 0
    max: 2147483647
pricing:
  input: '0.0007'
  output: '0.0007'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class MistralAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials;
        if validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model='open-mistral-7b',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
@@ -0,0 +1,31 @@
provider: mistralai
label:
  en_US: MistralAI
description:
  en_US: Models provided by MistralAI, such as open-mistral-7b and mistral-large-latest.
  zh_Hans: MistralAI 提供的模型,例如 open-mistral-7b 和 mistral-large-latest。
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from MistralAI
    zh_Hans: 从 MistralAI 获取 API Key
  url:
    en_US: https://console.mistral.ai/api-keys/
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@@ -13,6 +13,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        user = user[:32] if user else None
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.0005'
  output: '0.0015'
@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '0.0015'
  output: '0.002'
@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.001'
  output: '0.002'
@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 16385
  - name: response_format
    use_template: response_format
pricing:
  input: '0.003'
  output: '0.004'
@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 16385
  - name: response_format
    use_template: response_format
pricing:
  input: '0.003'
  output: '0.004'
@@ -21,6 +21,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '0.0015'
  output: '0.002'
@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.001'
  output: '0.002'
@@ -37,9 +37,6 @@ parameter_rules:
      the same result. Determinism is not guaranteed, and you should refer to the
      system_fingerprint response parameter to monitor changes in the backend.
    required: false
    precision: 2
    min: 0
    max: 1
  - name: response_format
    label:
      zh_Hans: 回复格式
@@ -37,9 +37,6 @@ parameter_rules:
      the same result. Determinism is not guaranteed, and you should refer to the
      system_fingerprint response parameter to monitor changes in the backend.
    required: false
    precision: 2
    min: 0
    max: 1
  - name: response_format
    label:
      zh_Hans: 回复格式
@@ -37,9 +37,6 @@ parameter_rules:
      the same result. Determinism is not guaranteed, and you should refer to the
      system_fingerprint response parameter to monitor changes in the backend.
    required: false
    precision: 2
    min: 0
    max: 1
  - name: response_format
    label:
      zh_Hans: 回复格式
@@ -37,9 +37,6 @@ parameter_rules:
      the same result. Determinism is not guaranteed, and you should refer to the
      system_fingerprint response parameter to monitor changes in the backend.
    required: false
    precision: 2
    min: 0
    max: 1
  - name: response_format
    label:
      zh_Hans: 回复格式
@@ -35,9 +35,6 @@ parameter_rules:
      the same result. Determinism is not guaranteed, and you should refer to the
      system_fingerprint response parameter to monitor changes in the backend.
    required: false
    precision: 2
    min: 0
    max: 1
  - name: response_format
    label:
      zh_Hans: 回复格式
@@ -37,9 +37,6 @@ parameter_rules:
      the same result. Determinism is not guaranteed, and you should refer to the
      system_fingerprint response parameter to monitor changes in the backend.
    required: false
    precision: 2
    min: 0
    max: 1
  - name: response_format
    label:
      zh_Hans: 回复格式
@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI

logger = logging.getLogger(__name__)

OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""

class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
    """
@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
            user=user
        )

    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        # handle fine tune remote models
        base_model = model
        if model.startswith('ft:'):
            base_model = model.split(':')[1]

        # get model mode
        model_mode = self.get_model_mode(base_model, credentials)

        # transform response format
        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
            stop = stop or []
            if model_mode == LLMMode.CHAT:
                # chat model
                self._transform_chat_json_prompts(
                    model=base_model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    response_format=model_parameters['response_format']
                )
            else:
                self._transform_completion_json_prompts(
                    model=base_model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    response_format=model_parameters['response_format']
                )
            model_parameters.pop('response_format')

        return self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )
    def _transform_chat_json_prompts(self, model: str, credentials: dict,
                                     prompt_messages: list[PromptMessage], model_parameters: dict,
                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
            -> None:
        """
        Transform json prompts
        """
        if "```\n" not in stop:
            stop.append("```\n")
        if "\n```" not in stop:
            stop.append("\n```")

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=OPENAI_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", prompt_messages[0].content)
                .replace("{{block}}", response_format)
            )
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=OPENAI_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                .replace("{{block}}", response_format)
            ))
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))

    def _transform_completion_json_prompts(self, model: str, credentials: dict,
                                           prompt_messages: list[PromptMessage], model_parameters: dict,
                                           tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                           stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
            -> None:
        """
        Transform json prompts
        """
        if "```\n" not in stop:
            stop.append("```\n")
        if "\n```" not in stop:
            stop.append("\n```")

        # override the last user message
        user_message = None
        for i in range(len(prompt_messages) - 1, -1, -1):
            if isinstance(prompt_messages[i], UserPromptMessage):
                user_message = prompt_messages[i]
                break

        if user_message:
            if prompt_messages[i].content[-11:] == 'Assistant: ':
                # now we are in the chat app, remove the last assistant message
                prompt_messages[i].content = prompt_messages[i].content[:-11]
                prompt_messages[i] = UserPromptMessage(
                    content=OPENAI_BLOCK_MODE_PROMPT
                    .replace("{{instructions}}", user_message.content)
                    .replace("{{block}}", response_format)
                )
                prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
            else:
                prompt_messages[i] = UserPromptMessage(
                    content=OPENAI_BLOCK_MODE_PROMPT
                    .replace("{{instructions}}", user_message.content)
                    .replace("{{block}}", response_format)
                )
                prompt_messages[i].content += f"\n```{response_format}\n"
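The effect of `_transform_chat_json_prompts` is easier to see on a concrete message list. A before/after sketch in which plain dicts stand in for the PromptMessage classes:

```python
# Plain-dict approximation of the chat transformation above (response_format='JSON').
before = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Give me a JSON user object."},
]
after = [
    {"role": "system", "content": (
        'You should always follow the instructions and output a valid JSON object.\n'
        'The structure of the JSON object can be found in the instructions, '
        'use {"answer": "$your_answer"} as the default structure\n'
        'if you are not sure about the structure.\n\n'
        "<instructions>\nYou are a helpful assistant.\n</instructions>\n")},
    {"role": "user", "content": "Give me a JSON user object."},
    # Pre-filled assistant turn: the model continues inside the fence, and the
    # added "```\n" / "\n```" stop sequences end generation at the closing fence.
    {"role": "assistant", "content": "\n```JSON\n"},
]
```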
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
@@ -2,6 +2,30 @@ model: tts-1-hd
model_type: tts
model_properties:
  default_voice: 'alloy'
  voices:
    - mode: 'alloy'
      name: 'Alloy'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
    - mode: 'echo'
      name: 'Echo'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
    - mode: 'fable'
      name: 'Fable'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
    - mode: 'onyx'
      name: 'Onyx'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
    - mode: 'nova'
      name: 'Nova'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
    - mode: 'shimmer'
      name: 'Shimmer'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
  word_limit: 120
  audio_type: 'mp3'
  max_workers: 5
pricing:
  input: '0.03'
  output: '0'
  unit: '0.001'
  currency: USD
@@ -2,6 +2,30 @@ model: tts-1
model_type: tts
model_properties:
  default_voice: 'alloy'
  voices:
    - mode: 'alloy'
      name: 'Alloy'
      language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
    - mode: 'echo'
      name: 'Echo'
      language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
    - mode: 'fable'
      name: 'Fable'
      language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
    - mode: 'onyx'
      name: 'Onyx'
      language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
    - mode: 'nova'
      name: 'Nova'
      language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
    - mode: 'shimmer'
      name: 'Shimmer'
      language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
  word_limit: 120
  audio_type: 'mp3'
  max_workers: 5
pricing:
  input: '0.015'
  output: '0'
  unit: '0.001'
  currency: USD
@@ -11,33 +11,40 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
from extensions.ext_storage import storage


class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
    """
    Model class for OpenAI text-to-speech model.
    """
    def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None) -> any:

    def _invoke(self, model: str, tenant_id: str, credentials: dict,
                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
        self._is_ffmpeg_installed()
        audio_type = self._get_model_audio_type(model, credentials)
        if not voice:
            voice = self._get_model_default_voice(model, credentials)
        if streaming:
            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
                                                                           credentials=credentials,
                                                                           content_text=content_text,
                                                                           user=user)),
                                                                           tenant_id=tenant_id,
                                                                           voice=voice)),
                            status=200, mimetype=f'audio/{audio_type}')
        else:
            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, user=user)
            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
@@ -52,91 +59,96 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
            self._tts_invoke(
                model=model,
                credentials=credentials,
                content_text='Hello world!',
                user=user
                content_text='Hello Dify!',
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> Response:
    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
        """
        _tts_invoke text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param user: unique user id
        :param voice: model timbre
        :return: text translated to audio file
        """
        audio_type = self._get_model_audio_type(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)

        try:
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            audio_bytes_list = list()

            # Create a thread pool and map the function to the list of sentences
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(self._process_sentence, sentence, model, credentials) for sentence
                           in sentences]
                futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
                                           credentials=credentials) for sentence in sentences]
                for future in futures:
                    try:
                        audio_bytes_list.append(future.result())
                        if future.result():
                            audio_bytes_list.append(future.result())
                    except Exception as ex:
                        raise InvokeBadRequestError(str(ex))

            audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
                              audio_bytes_list if audio_bytes]
            combined_segment = reduce(lambda x, y: x + y, audio_segments)
            buffer: BytesIO = BytesIO()
            combined_segment.export(buffer, format=audio_type)
            buffer.seek(0)
            return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
            if len(audio_bytes_list) > 0:
                audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
                                  audio_bytes_list if audio_bytes]
                combined_segment = reduce(lambda x, y: x + y, audio_segments)
                buffer: BytesIO = BytesIO()
                combined_segment.export(buffer, format=audio_type)
                buffer.seek(0)
                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    # Todo: To improve the streaming function
    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param user: unique user id
        :param voice: model timbre
        :return: text translated to audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        voice_name = self._get_model_voice(model, credentials)
        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
            voice = self._get_model_default_voice(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials)
        audio_type = self._get_model_audio_type(model, credentials)
        tts_file_id = self._get_file_name(content_text)
        file_path = f'storage/generate_files/{audio_type}/{tts_file_id}.{audio_type}'
        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            client = OpenAI(**credentials_kwargs)
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            for sentence in sentences:
                response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip())
                response.stream_to_file(file_path)
                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
                # response.stream_to_file(file_path)
                storage.save(file_path, response.read())
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    def _process_sentence(self, sentence: str, model: str, credentials: dict):
    def _process_sentence(self, sentence: str, model: str,
                          voice, credentials: dict):
        """
        _tts_invoke openai text2speech model api

        :param model: model name
        :param credentials: model credentials
        :param voice: model timbre
        :param sentence: text content to be translated
        :return: text translated to audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        voice_name = self._get_model_voice(model, credentials)

        client = OpenAI(**credentials_kwargs)
        response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip())
        response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
        if isinstance(response.read(), bytes):
            return response.read()
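`_tts_invoke` above fans sentences out to a thread pool and stitches the returned audio back together in order. A minimal standalone sketch of that fan-out/concatenate pattern, using pydub as the code above does and a stubbed synthesis call in place of `client.audio.speech.create`:

```python
# Simplified fan-out/merge pattern from _tts_invoke; synthesize() is a stub.
import concurrent.futures
from functools import reduce
from io import BytesIO

from pydub import AudioSegment  # same library the code above relies on


def synthesize(sentence: str) -> bytes:
    raise NotImplementedError  # stands in for client.audio.speech.create(...)


def tts_concat(sentences: list[str], max_workers: int = 5) -> BytesIO:
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(synthesize, s) for s in sentences]
        audio_bytes = [f.result() for f in futures if f.result()]
    # decode each sentence's audio and concatenate the segments in order
    segments = [AudioSegment.from_file(BytesIO(b), format="mp3") for b in audio_bytes]
    combined = reduce(lambda x, y: x + y, segments)
    buffer = BytesIO()
    combined.export(buffer, format="mp3")
    buffer.seek(0)
    return buffer
```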
@@ -367,13 +367,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
            if chunk:
                # ignore sse comments
                if chunk.startswith(':'):
                    continue
                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                chunk_json = None
                try:
                    chunk_json = json.loads(decoded_chunk)
                # stream ended
                except json.JSONDecodeError as e:
                    logger.error(f"decoded_chunk error, delimiter={delimiter}, decoded_chunk={decoded_chunk}")
                    yield create_final_llm_result_chunk(
                        index=chunk_index + 1,
                        message=AssistantPromptMessage(content=""),
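The loop above is plain server-sent-events plumbing: skip `:`-prefixed comment lines, strip the `data: ` prefix, and treat an undecodable payload as the end of the stream. A standalone sketch of that parsing step; note that the real code's `lstrip('data: ')` strips a character set, which this sketch replaces with an explicit prefix check:

```python
import json


def parse_sse_line(chunk: str):
    """Return the decoded JSON payload of one SSE line, or None to skip it."""
    if not chunk or chunk.startswith(":"):  # empty line or SSE comment
        return None
    payload = chunk.strip()
    if payload.startswith("data: "):
        payload = payload[len("data: "):]
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return None  # the code above emits a final chunk and stops here


assert parse_sse_line('data: {"id": 1}') == {"id": 1}
assert parse_sse_line(": keep-alive") is None
```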
@@ -13,6 +13,7 @@ from dashscope.common.error import (
)
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

    def _code_block_mode_wrapper(self, model: str, credentials: dict,
                                 prompt_messages: list[PromptMessage], model_parameters: dict,
                                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                 stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
            -> LLMResult | Generator:
        """
        Wrapper for code block mode
        """
        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""

        code_block = model_parameters.get("response_format", "")
        if not code_block:
            return self._invoke(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user
            )

        model_parameters.pop("response_format")
        stop = stop or []
        stop.extend(["\n```", "```\n"])
        block_prompts = block_prompts.replace("{{block}}", code_block)

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", prompt_messages[0].content)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
            ))

        mode = self.get_model_mode(model, credentials)
        if mode == LLMMode.CHAT:
            if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
                # add ```JSON\n to the last message
                prompt_messages[-1].content += f"\n```{code_block}\n"
            else:
                # append a user message
                prompt_messages.append(UserPromptMessage(
                    content=f"```{code_block}\n"
                ))
        else:
            prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))

        response = self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

        if isinstance(response, Generator):
            return self._code_block_mode_stream_processor_with_backtick(
                model=model,
                prompt_messages=prompt_messages,
                input_generator=response
            )

        return response
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        """
        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs['stop_sequences'] = stop
            extra_model_kwargs['stop'] = stop

        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        params = {
            'model': model,
            **model_parameters,
            **credentials_kwargs
            **credentials_kwargs,
            **extra_model_kwargs,
        }

        mode = self.get_model_mode(model, credentials)
@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
    required: false
  - name: response_format
    use_template: response_format
@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
    required: false
  - name: response_format
    use_template: response_format
@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
    required: false
  - name: response_format
    use_template: response_format
Some files were not shown because too many files have changed in this diff.