Compare commits


69 Commits
0.5.2 ... 0.5.3

SHA1 Message Date
9f637ead38 bump version to 0.5.3 (#2306) 2024-02-01 18:11:57 +08:00
b521aafd26 chore(web): strong typing (#2339) 2024-02-01 18:07:26 +08:00
a84e15b8cc fix: ignore spark provider credential validate (#2344) 2024-02-01 18:04:05 +08:00
0c330fc020 feat: add xinference llm context size (#2336) 2024-02-01 17:10:45 +08:00
cfbb7bec58 Feat/current time tool zone (#2337) 2024-02-01 17:09:59 +08:00
3b357f51a6 fix: first agent latency (#2334) 2024-02-01 15:30:50 +08:00
09acf215f0 add option to prompt for a validation password when initializing admin user (#2302) 2024-02-01 15:03:56 +08:00
07dd8b94ed fix: check empty tool provider credentials (#2332) 2024-02-01 13:13:28 +08:00
ef308fd121 feat: add sd model parameter (#2331) 2024-02-01 13:12:57 +08:00
fce64d760b fix: empty model features (#2330) 2024-02-01 13:11:11 +08:00
f0c9bb7c91 fix: typo (#2318) 2024-02-01 13:08:31 +08:00
d8672796b0 revert: remove unused session store codes (#2329) 2024-02-01 12:10:05 +08:00
5929e84036 Optimization stable diffusion verify (#2322)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-01 12:05:09 +08:00
83063532a0 Fix/api tool (#2317) 2024-02-01 09:10:32 +08:00
07279558a5 Change ZHIPU_MAX_LIMITS to 5. Fix issue 2323 (#2324) 2024-02-01 09:06:32 +08:00
2166473852 Feat/add spark3.5 llm (#2314)
Co-authored-by: lux@njuelectronics.com <lux@njuelectronics.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-01-31 17:57:17 +08:00
44397e3062 remove unused session store codes (#2313) 2024-01-31 15:30:35 +08:00
883a0a0e6a chore: detect is function calling from model config (#2312) 2024-01-31 14:06:27 +08:00
b5ed81b349 fix: invalid server tool url caused crash (#2311) 2024-01-31 14:04:54 +08:00
625b0afa52 fix: next public edition default value (#2310) 2024-01-31 12:32:13 +08:00
2660fbaa20 Fix/typos (#2308) 2024-01-31 11:58:07 +08:00
9e37702d24 feat: ui improvements for Portuguese (#2304) 2024-01-31 11:25:33 +08:00
bc11c6a7f2 feat: recommended apps list support sort by position (#2303) 2024-01-31 11:00:44 +08:00
10e9766fd3 chore:azure dalle tool support pt-BR text (#2301)
Co-authored-by: lux@njuelectronics.com <lux@njuelectronics.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-01-30 23:49:19 +08:00
6d24a2cb87 fix: api tool encoding (#2296) 2024-01-30 22:22:58 +08:00
0a4dfaeaf9 Feat: Add Top bar while routing different different pages (#2298) 2024-01-30 20:22:17 +08:00
c0a4fd145c Add custom tools (#2299)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-01-30 19:59:22 +08:00
70f16e1a0b fix: keep original tool credentials (#2288) 2024-01-30 18:41:36 +08:00
cb27571e9f fix: missing prompt (#2294) 2024-01-30 17:00:50 +08:00
0518da5819 remove repositories tool (#2293) 2024-01-30 16:51:36 +08:00
d2797abdb4 Add custom tools (#2292)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-01-30 16:33:49 +08:00
bf3ee660e0 fix: missing files (#2291) 2024-01-30 16:21:40 +08:00
68406b9906 fix: multiple model configuration clear conversation by rerender (#2286) 2024-01-30 16:06:01 +08:00
6f7fd6613a feat: file icon support doc and docx (#2289) 2024-01-30 15:55:07 +08:00
6d5b386394 Feat/blocking function call (#2247) 2024-01-30 15:25:37 +08:00
1ea18a2922 feat: optimize tool name (#2284) 2024-01-30 14:58:59 +08:00
f8f4b961a1 chore: handle app name and options too long (#2283) 2024-01-30 14:53:10 +08:00
57565db531 feat: some unused command-line tasks were removed. (#2281) 2024-01-30 14:33:48 +08:00
d844420c07 bump flask from 2.3 to 3.0 (#2279) 2024-01-30 13:35:13 +08:00
34634bddf1 fix: setting default model to gpt-3.5-turbo-1106 and remove default m… (#2274) 2024-01-30 13:04:17 +08:00
c97b7f6748 Feat/add azure dalle tool (#2276)
Co-authored-by: lux@njuelectronics.com <lux@njuelectronics.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-01-30 11:38:58 +08:00
76cc19f525 Add custom tools (#2259)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-01-30 11:03:20 +08:00
5baaebb3fd fix: typo of builtin tools (#2275) 2024-01-30 08:09:31 +08:00
9d072920da fix: remove finish_reason condition logic when deltaContent is empty (#2270)
Co-authored-by: wanggang <wanggy01@servyou.com.cn>
2024-01-29 23:24:13 +08:00
965ca36525 use pm2 to guard and monitor the web service in docker file (#2238) 2024-01-29 18:21:15 +08:00
b4988ce20c fix: missing keys language in parser (#2271) 2024-01-29 17:59:59 +08:00
d3d617239f Feat/utm update (#2269)
Co-authored-by: Joel <iamjoel007@gmail.com>
2024-01-29 17:31:45 +08:00
6c3b34a61d chore: update price page (#2272) 2024-01-29 17:26:43 +08:00
d76d1adb59 feat: Nodejs sdk support auto rename conversation api (#2265) 2024-01-29 12:57:39 +08:00
cadc6b171e chore: change expert mode the same line height as automatic (#2263) 2024-01-29 11:10:19 +08:00
fdae2a20ae fix: stop generate api doc error (#2262) 2024-01-29 11:10:07 +08:00
45701a81e9 fix: initial paragraph can not input more than 48 chars (#2258) 2024-01-29 09:58:29 +08:00
409e0c8e1c update qdrant migrate command (#2260)
Co-authored-by: jyong <jyong@dify.ai>
2024-01-28 19:59:06 +08:00
7076d41b29 Bugfix/invitemailmultilangs (#2257)
Co-authored-by: crazywoola <427733928@qq.com>
2024-01-28 19:56:09 +08:00
5a6cb69951 fix: user handling in stop api (#2254) 2024-01-27 19:05:37 +08:00
11a75ee78a fix: remove invalid parameter return_type (#2253) 2024-01-27 14:29:25 +08:00
b9b692d71d fix typo (#2248) 2024-01-27 03:56:23 +08:00
d8f8afcbd0 fix: Resolved the issue of duplicate display of supported file types during text file upload (#2241)
Co-authored-by: hbc <hbc@hbc-iMac.local>
2024-01-26 19:44:49 +08:00
8cb62ef31a Maintenance notice href (#2234)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2024-01-26 19:14:39 +08:00
bb5d5fc683 Feat/billing enhancement (#2239)
Co-authored-by: takatost <takatost@gmail.com>
2024-01-26 18:26:15 +08:00
2fc0dcc10a feat: team admin can pay billing (#2240) 2024-01-26 18:06:54 +08:00
9fd55157d6 fix: vision config (#2235) 2024-01-26 17:12:16 +08:00
6c384dba71 fix: register ga id error (#2237) 2024-01-26 17:11:52 +08:00
9730297381 chore: move register ga to signin page (#2233) 2024-01-26 15:50:14 +08:00
99e80a8ed0 fix:Bedrock llm issue #2214 (#2215)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
2024-01-26 15:34:29 +08:00
26fef2d481 Maintenance notice href (#2228)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2024-01-26 15:28:33 +08:00
c9e65f4221 Fix/update broken doc links (#2187)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-01-26 15:20:03 +08:00
20bd33fada feat: prompt IDE support change height (#2232) 2024-01-26 15:13:06 +08:00
bd0af2e921 fix: occasional multiple responses displayed in frontend due to unexpected message_id from onData (#2231) 2024-01-26 15:08:37 +08:00
343 changed files with 5066 additions and 3717 deletions

View File

@@ -30,7 +30,7 @@ from flask import Flask, Response, request
from flask_cors import CORS
from libs.passport import PassportService
# DO NOT REMOVE BELOW
from models import account, dataset, model, source, task, tool, web, tools
from models import account, dataset, model, source, task, tool, tools, web
from services.account_service import AccountService
# DO NOT REMOVE ABOVE

View File

@@ -1,32 +1,20 @@
import base64
import datetime
import json
import math
import random
import secrets
import string
import threading
import time
import uuid
import click
import qdrant_client
from constants.languages import user_input_form_template
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from flask import Flask, current_app
from flask import current_app
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document
from models.model import Account, App, AppModelConfig, Message, MessageAnnotation, InstalledApp
from models.provider import Provider, ProviderModel, ProviderQuotaType, ProviderType
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from models.account import Tenant
from models.dataset import Dataset
from models.model import Account
from models.provider import Provider, ProviderModel
from werkzeug.exceptions import NotFound
@@ -35,15 +23,22 @@ from werkzeug.exceptions import NotFound
@click.option('--new-password', prompt=True, help='the new password.')
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
def reset_password(email, new_password, password_confirm):
"""
Reset password of owner account
Only available in SELF_HOSTED mode
"""
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
return
try:
valid_password(new_password)
except:
@@ -69,15 +64,22 @@ def reset_password(email, new_password, password_confirm):
@click.option('--new-email', prompt=True, help='the new email.')
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
def reset_email(email, new_email, email_confirm):
"""
Replace account email
:return:
"""
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
return
try:
email_validate(new_email)
except:
@@ -97,6 +99,11 @@ def reset_email(email, new_email, email_confirm):
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
' this operation cannot be rolled back!', fg='red'))
def reset_encrypt_key_pair():
"""
Reset the encrypted key pair of workspace for encrypt LLM credentials.
After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode.
"""
if current_app.config['EDITION'] != 'SELF_HOSTED':
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
return
@@ -116,201 +123,11 @@ def reset_encrypt_key_pair():
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
@click.command('generate-invitation-codes', help='Generate invitation codes.')
@click.option('--batch', help='The batch of invitation codes.')
@click.option('--count', prompt=True, help='Invitation codes count.')
def generate_invitation_codes(batch, count):
if not batch:
now = datetime.datetime.now()
batch = now.strftime('%Y%m%d%H%M%S')
if not count or int(count) <= 0:
click.echo(click.style('sorry. the count must be greater than 0.', fg='red'))
return
count = int(count)
click.echo('Start generate {} invitation codes for batch {}.'.format(count, batch))
codes = ''
for i in range(count):
code = generate_invitation_code()
invitation_code = InvitationCode(
code=code,
batch=batch
)
db.session.add(invitation_code)
click.echo(code)
codes += code + "\n"
db.session.commit()
filename = 'storage/invitation-codes-{}.txt'.format(batch)
with open(filename, 'w') as f:
f.write(codes)
click.echo(click.style(
'Congratulations! Generated {} invitation codes for batch {} and saved to the file \'{}\''.format(count, batch,
filename),
fg='green'))
def generate_invitation_code():
code = generate_upper_string()
while db.session.query(InvitationCode).filter(InvitationCode.code == code).count() > 0:
code = generate_upper_string()
return code
def generate_upper_string():
letters_digits = string.ascii_uppercase + string.digits
result = ""
for i in range(8):
result += random.choice(letters_digits)
return result
@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
def recreate_all_dataset_indexes():
click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
recreate_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
try:
click.echo('Recreating dataset index: {}'.format(dataset.id))
index = IndexBuilder.get_index(dataset, 'high_quality')
if index and index._is_origin():
index.recreate_dataset(dataset)
recreate_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
@click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
def clean_unused_dataset_indexes():
click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
start_at = time.perf_counter()
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
dataset_query = db.session.query(DatasetQuery).filter(
DatasetQuery.created_at > thirty_days_ago,
DatasetQuery.dataset_id == dataset.id
).all()
if not dataset_query or len(dataset_query) == 0:
documents = db.session.query(Document).filter(
Document.dataset_id == dataset.id,
Document.indexing_status == 'completed',
Document.enabled == True,
Document.archived == False,
Document.updated_at > thirty_days_ago
).all()
if not documents or len(documents) == 0:
try:
# remove index
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
vector_index.delete()
kw_index.delete()
# update document
update_params = {
Document.enabled: False
}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
fg='green'))
except Exception as e:
click.echo(
click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
end_at = time.perf_counter()
click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
if not hosted_model_providers.anthropic:
click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
return
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
new_quota_limit = hosted_model_providers.anthropic.quota_limit
page = 1
while True:
try:
providers = db.session.query(Provider).filter(
Provider.provider_name == 'anthropic',
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.quota_limit != new_quota_limit
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
page += 1
for provider in providers:
try:
click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
.format(provider.tenant_id, provider.quota_limit, provider.quota_used))
original_quota_limit = provider.quota_limit
division = math.ceil(new_quota_limit / 1000)
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
else original_quota_limit * division
provider.quota_used = division * provider.quota_used
db.session.commit()
count += 1
except Exception as e:
click.echo(click.style(
'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
"""
Migrate other vector database datas to Qdrant.
"""
click.echo(click.style('Start create qdrant indexes.', fg='green'))
create_count = 0
@@ -339,26 +156,7 @@ def create_qdrant_indexes():
)
except Exception:
try:
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.SYSTEM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
continue
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
@@ -393,380 +191,8 @@ def create_qdrant_indexes():
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
def update_qdrant_indexes():
click.echo(click.style('Start Update qdrant indexes.', fg='green'))
create_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Update dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.update_qdrant_dataset(dataset)
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
click.echo(click.style('Start normalization collections.', fg='green'))
normalization_count = []
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
datasets_result = datasets.items
page += 1
for i in range(0, len(datasets_result), 5):
threads = []
sub_datasets = datasets_result[i:i + 5]
for dataset in sub_datasets:
document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'normalization_count': normalization_count
})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
def add_qdrant_full_text_index():
click.echo(click.style('Start add full text index.', fg='green'))
binds = db.session.query(DatasetCollectionBinding).all()
if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
qdrant_url = current_app.config['QDRANT_URL']
qdrant_api_key = current_app.config['QDRANT_API_KEY']
client = qdrant_client.QdrantClient(
qdrant_url,
api_key=qdrant_api_key, # For Qdrant Cloud, None for local instance
)
for bind in binds:
try:
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
client.create_payload_index(bind.collection_name, 'page_content',
field_schema=text_index_params)
except Exception as e:
click.echo(
click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(
'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
fg='green'))
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
with flask_app.app_context():
try:
click.echo('restore dataset index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
DatasetCollectionBinding.model_name == embedding_model.name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=embedding_model.model_provider.provider_name,
model_name=embedding_model.name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.commit()
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
# index.delete_by_group_id(dataset.id)
index.restore_dataset_in_one(dataset, dataset_collection_binding)
else:
click.echo('passed.')
normalization_count.append(1)
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
pre_prompt_template = '{{default_input}}'
click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')
total_records = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
num_batches = (total_records + batch_size - 1) // batch_size
with tqdm(total=total_records, desc="Migrating Data") as pbar:
for i in range(num_batches):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
try:
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
for data in data_batch:
# click.secho(f"Migrating data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
if data.pre_prompt is None:
data.pre_prompt = pre_prompt_template
else:
if pre_prompt_template in data.pre_prompt:
continue
data.pre_prompt += pre_prompt_template
app_data = db.session.query(App) \
.filter(App.id == data.app_id) \
.one()
account_data = db.session.query(Account) \
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
.filter(TenantAccountJoin.role == 'owner') \
.filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
.one_or_none()
if not account_data:
continue
if data.user_input_form is None or data.user_input_form == 'null':
data.user_input_form = json.dumps(user_input_form_template[account_data.interface_language])
else:
raw_json_data = json.loads(data.user_input_form)
raw_json_data.append(user_input_form_template[account_data.interface_language][0])
data.user_input_form = json.dumps(raw_json_data)
# click.secho(f"Updated data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
db.session.commit()
except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
click.secho("Starting...", fg='green')
total_records = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.filter(AppModelConfig.dataset_query_variable == None) \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
num_batches = (total_records + batch_size - 1) // batch_size
with tqdm(total=total_records, desc="Migrating Data") as pbar:
for i in range(num_batches):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.filter(AppModelConfig.dataset_query_variable == None) \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
try:
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
for data in data_batch:
config = AppModelConfig.to_dict(data)
tools = config["agent_mode"]["tools"]
dataset_exists = "dataset" in str(tools)
if not dataset_exists:
continue
user_input_form = config.get("user_input_form", [])
for form in user_input_form:
paragraph = form.get('paragraph')
if paragraph \
and paragraph.get('variable') == 'query':
data.dataset_query_variable = 'query'
break
if paragraph \
and paragraph.get('variable') == 'default_input':
data.dataset_query_variable = 'default_input'
break
db.session.commit()
except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
click.echo(click.style('Start add annotation question value.', fg='green'))
message_annotations = db.session.query(MessageAnnotation).all()
message_annotation_deal_count = 0
if message_annotations:
for message_annotation in message_annotations:
try:
if message_annotation.message_id and not message_annotation.question:
message = db.session.query(Message).filter(
Message.id == message_annotation.message_id
).first()
message_annotation.question = message.query
db.session.add(message_annotation)
db.session.commit()
message_annotation_deal_count += 1
except Exception as e:
click.echo(
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
app.cli.add_command(add_annotation_question_field_value)

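For reference, the maintenance commands that survive this cleanup are ordinary Flask CLI commands, so they can be exercised programmatically as well as from the shell. A minimal sketch, assuming the conventional `create_app` factory in app.py (the factory itself is not part of this diff):

```python
# Hedged usage sketch: drive the reset-password command through Flask's CLI
# test runner instead of the shell. `create_app` is an assumed entry point.
from app import create_app

app = create_app()
runner = app.test_cli_runner()

# Equivalent to:
#   flask reset-password --email ... --new-password ... --password-confirm ...
result = runner.invoke(args=[
    'reset-password',
    '--email', 'owner@example.com',
    '--new-password', 'S3cret-Passw0rd',
    '--password-confirm', 'S3cret-Passw0rd',
])
print(result.output)  # the click.echo output, e.g. a success or mismatch message
```

Passing the options on the command line skips the interactive prompts that `prompt=True` would otherwise trigger.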
View File

@@ -40,17 +40,11 @@ DEFAULTS = {
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_OPENAI_PAID_MIN_QUANTITY': 1,
'HOSTED_OPENAI_PAID_MAX_QUANTITY': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_TRIAL_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 1,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 1,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'CLEAN_DAY_SETTING': 30,
@@ -93,7 +87,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.2"
self.CURRENT_VERSION = "0.5.3"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -262,10 +256,6 @@ class Config:
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
self.HOSTED_OPENAI_PAID_MIN_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MIN_QUANTITY'))
self.HOSTED_OPENAI_PAID_MAX_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MAX_QUANTITY'))
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -277,10 +267,6 @@ class Config:
self.HOSTED_ANTHROPIC_TRIAL_ENABLED = get_bool_env('HOSTED_ANTHROPIC_TRIAL_ENABLED')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MINIMAX_ENABLED = get_bool_env('HOSTED_MINIMAX_ENABLED')
self.HOSTED_SPARK_ENABLED = get_bool_env('HOSTED_SPARK_ENABLED')

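The `get_env`/`get_bool_env` helpers and the `DEFAULTS` table interact in a simple way that explains the removed `HOSTED_*_PAID_*` entries: each deleted default has a matching `Config` attribute deleted below it. A hedged sketch of the usual pattern (the helpers themselves are not shown in this diff):

```python
# Sketch of the env-with-defaults pattern assumed by the Config class above.
import os

DEFAULTS = {
    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
}

def get_env(key):
    # Fall back to the DEFAULTS table when the variable is unset.
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key):
    # Boolean settings are stored as the strings 'True' / 'False'.
    return str(get_env(key)).lower() == 'true'

print(get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED'))   # False by default
print(int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')))     # 200 by default
```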
View File

@@ -1,5 +1,6 @@
import json
from models.model import AppModelConfig
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']

View File

@@ -11,13 +11,11 @@ from .app import (advanced_prompt_template, annotation, app, audio, completion,
model_config, site, statistic)
# Import auth controllers
from .auth import activate, data_source_oauth, login, oauth
# Import billing controllers
from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
# Import explore controllers
from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message
# Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace
# Import billing controllers
from .billing import billing
# Import operation controllers
from .operation import operation

View File

@@ -1,12 +1,12 @@
import os
from functools import wraps
from constants.languages import supported_language
from controllers.console import api
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from flask import request
from flask_restful import Resource, reqparse
from constants.languages import supported_language
from models.model import App, InstalledApp, RecommendedApp
from werkzeug.exceptions import NotFound, Unauthorized

View File

@@ -3,8 +3,8 @@ import json
import logging
from datetime import datetime
from constants.model_template import model_templates
from constants.languages import demo_model_templates, languages
from constants.model_template import model_templates
from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
@@ -26,6 +26,7 @@ from models.tools import ApiToolProvider
from services.app_model_config_service import AppModelConfigService
from werkzeug.exceptions import Forbidden
def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
if not app:
@@ -107,20 +108,33 @@ class AppListApi(Resource):
# validate config
model_config_dict = args['model_config']
# get model provider
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
# Get provider configurations
provider_manager = ProviderManager()
provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id)
# get available models from provider_configurations
available_models = provider_configurations.get_models(
model_type=ModelType.LLM,
only_active=True
)
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
# check if model is available
available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
if provider_model not in available_models_names:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,

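The new block above only falls back to the workspace default LLM when the requested `provider.model` pair is not among the active models. Here is that decision in isolation, with plain tuples standing in for the `ProviderManager` objects (data is illustrative):

```python
# Self-contained sketch of the availability check added above.
requested = {'provider': 'openai', 'name': 'gpt-4'}

# Stand-ins for provider_configurations.get_models(model_type=LLM, only_active=True)
available = [('openai', 'gpt-3.5-turbo'), ('anthropic', 'claude-2')]
available_names = [f'{provider}.{model}' for provider, model in available]

if f"{requested['provider']}.{requested['name']}" not in available_names:
    # Not available: fall back to the default model, as the controller now does.
    requested = {'provider': 'openai', 'name': 'gpt-3.5-turbo'}

print(requested)
```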
View File

@@ -1,4 +1,5 @@
# -*- coding:utf-8 -*-
from constants.languages import supported_language
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
@@ -7,7 +8,6 @@ from extensions.ext_database import db
from fields.app_fields import app_site_fields
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from constants.languages import supported_language
from libs.login import login_required
from models.model import Site
from werkzeug.exceptions import Forbidden, NotFound

View File

@@ -2,12 +2,12 @@ import base64
import secrets
from datetime import datetime
from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from flask_restful import Resource, reqparse
from libs.helper import email, str_len, timezone
from constants.languages import supported_language
from libs.password import hash_password, valid_password
from models.account import AccountStatus, Tenant
from services.account_service import RegisterService

View File

@@ -20,7 +20,7 @@ class Subscription(Resource):
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()
BillingService.is_tenant_owner(current_user)
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args['plan'],
args['interval'],
@@ -35,8 +35,8 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
BillingService.is_tenant_owner(current_user)
return BillingService.get_invoices(current_user.email)
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
api.add_resource(Subscription, '/billing/subscription')

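The rename from `is_tenant_owner` to `is_tenant_owner_or_admin` matches commit 2fc0dcc10a ("team admin can pay billing"). `BillingService` itself is outside this diff; a hedged sketch of the rule the rename implies, with illustrative attribute names:

```python
# Hedged sketch only: BillingService is not shown in this diff, and the
# attribute names below are illustrative, not confirmed by the source.
from werkzeug.exceptions import Forbidden

def is_tenant_owner_or_admin(current_user):
    # Previously only 'owner' passed this gate; 'admin' is now accepted too.
    if current_user.current_tenant.current_role not in ('owner', 'admin'):
        raise Forbidden('Only the team owner or an admin can manage billing.')
```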
View File

@@ -9,7 +9,7 @@ from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, marshal_with
from libs.login import login_required
from services.file_service import FileService, ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService
PREVIEW_WORDS_LIMIT = 3000

View File

@@ -13,6 +13,16 @@ class NotSetupError(BaseHTTPException):
"Please proceed with the initialization and installation process first."
code = 401
class NotInitValidateError(BaseHTTPException):
error_code = 'not_init_validated'
description = "Init validation has not been completed yet. " \
"Please proceed with the init validation process first."
code = 401
class InitValidateFailedError(BaseHTTPException):
error_code = 'init_validate_failed'
description = "Init validation failed. Please check the password and try again."
code = 401
class AccountNotLinkTenantError(BaseHTTPException):
error_code = 'account_not_link_tenant'

View File

@@ -17,9 +17,9 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import marshal_with, reqparse, fields
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from libs.helper import uuid_value, TimestampField
from libs.helper import TimestampField, uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError

View File

@@ -3,12 +3,12 @@ import json
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from extensions.ext_database import db
from flask import current_app
from flask_restful import fields, marshal_with
from models.model import InstalledApp, AppModelConfig
from models.model import AppModelConfig, InstalledApp
from models.tools import ApiToolProvider
from extensions.ext_database import db
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""

View File

@@ -1,4 +1,5 @@
# -*- coding:utf-8 -*-
from constants.languages import languages
from controllers.console import api
from controllers.console.app.error import AppNotFoundError
from controllers.console.wraps import account_initialization_required
@@ -9,7 +10,6 @@ from libs.login import login_required
from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from sqlalchemy import and_
from constants.languages import languages
app_fields = {
'id': fields.String,

View File

@@ -3,10 +3,12 @@ from flask_restful import Resource
from services.feature_service import FeatureService
from . import api
from .wraps import cloud_utm_record
class FeatureApi(Resource):
@cloud_utm_record
def get(self):
return FeatureService.get_features(current_user.current_tenant_id).dict()

View File

@@ -0,0 +1,48 @@
import os
from flask import current_app, session
from flask_restful import Resource, reqparse
from libs.helper import str_len
from models.model import DifySetup
from services.account_service import TenantService
from . import api
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
class InitValidateAPI(Resource):
def get(self):
init_status = get_init_validate_status()
if init_status:
return { 'status': 'finished' }
return {'status': 'not_started' }
@only_edition_self_hosted
def post(self):
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument('password', type=str_len(30),
required=True, location='json')
input_password = parser.parse_args()['password']
if input_password != os.environ.get('INIT_PASSWORD'):
session['is_init_validated'] = False
raise InitValidateFailedError()
session['is_init_validated'] = True
return {'result': 'success'}, 201
def get_init_validate_status():
if current_app.config['EDITION'] == 'SELF_HOSTED':
if os.environ.get('INIT_PASSWORD'):
return session.get('is_init_validated') or DifySetup.query.first()
return True
api.add_resource(InitValidateAPI, '/init')

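A usage sketch for the new init-validation endpoint. The console API prefix is an assumption (it is not part of this file); the statuses and the 201 response come from the handlers above:

```python
# Hedged sketch: assumes the console API is mounted under /console/api.
import requests

base = 'http://localhost:5001/console/api'

print(requests.get(f'{base}/init').json())  # {'status': 'not_started'} or {'status': 'finished'}

resp = requests.post(f'{base}/init', json={'password': 'value-of-INIT_PASSWORD'})
print(resp.status_code, resp.json())        # 201 {'result': 'success'} on a correct password
```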
View File

@@ -1,30 +0,0 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, only_edition_cloud
from libs.login import login_required
from services.operation_service import OperationService
class TenantUtm(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('utm_source', type=str, required=True)
parser.add_argument('utm_medium', type=str, required=True)
parser.add_argument('utm_campaign', type=str, required=False, default='')
parser.add_argument('utm_content', type=str, required=False, default='')
parser.add_argument('utm_term', type=str, required=False, default='')
args = parser.parse_args()
return OperationService.record_utm(current_user.current_tenant_id, args)
api.add_resource(TenantUtm, '/operation/utm')

View File

@@ -10,7 +10,8 @@ from models.model import DifySetup
from services.account_service import AccountService, RegisterService, TenantService
from . import api
from .error import AlreadySetupError, NotSetupError
from .error import AlreadySetupError, NotInitValidateError, NotSetupError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
@@ -24,7 +25,7 @@ class SetupApi(Resource):
'step': 'finished',
'setup_at': setup_status.setup_at.isoformat()
}
return {'step': 'not_start'}
return {'step': 'not_started'}
return {'step': 'finished'}
@only_edition_self_hosted
@@ -37,6 +38,9 @@
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
if not get_init_validate_status():
raise NotInitValidateError()
parser = reqparse.RequestParser()
parser.add_argument('email', type=email,
@@ -71,7 +75,10 @@ def setup_required(view):
@wraps(view)
def decorated(*args, **kwargs):
# check setup
if not get_setup_status():
if not get_init_validate_status():
raise NotInitValidateError()
elif not get_setup_status():
raise NotSetupError()
return view(*args, **kwargs)

View File

@@ -2,6 +2,7 @@
from datetime import datetime
import pytz
from constants.languages import supported_language
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import (AccountAlreadyInitedError, CurrentPasswordIncorrectError,
@@ -12,7 +13,6 @@ from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse
from libs.helper import TimestampField, timezone
from constants.languages import supported_language
from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService

View File

@@ -1,13 +1,12 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import current_user
from flask_restful import Resource, abort, fields, marshal_with, reqparse
import services
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from flask import current_app
from flask_login import current_user
from flask_restful import Resource, abort, fields, marshal_with, reqparse
from libs.helper import TimestampField
from libs.login import login_required
from models.account import Account
@@ -52,10 +51,12 @@ class MemberInviteEmailApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('emails', type=str, required=True, location='json', action='append')
parser.add_argument('role', type=str, required=True, default='admin', location='json')
parser.add_argument('language', type=str, required=False, location='json')
args = parser.parse_args()
invitee_emails = args['emails']
invitee_role = args['role']
interface_language = args['language']
if invitee_role not in ['admin', 'normal']:
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
@@ -64,8 +65,7 @@
console_web_url = current_app.config.get("CONSOLE_WEB_URL")
for invitee_email in invitee_emails:
try:
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
inviter=inviter)
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
invitation_results.append({
'status': 'success',
'email': invitee_email,

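With the new optional `language` argument, the invite request body gains a field selecting the invitation e-mail language (commit 7076d41b29). An illustrative payload matching the parser above:

```python
# Illustrative request body for MemberInviteEmailApi after this change.
payload = {
    'emails': ['teammate@example.com'],  # action='append': a list of addresses
    'role': 'normal',                    # must be 'admin' or 'normal'
    'language': 'pt-BR',                 # new optional field for the invite e-mail
}
```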
View File

@@ -186,10 +186,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str):
if provider != 'anthropic':
raise ValueError(f'provider name {provider} is invalid')
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(provider_name=provider,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id)
account_id=current_user.id,
prefilled_email=current_user.email)
return data

View File

@@ -1,18 +1,16 @@
import io
import json
from libs.login import login_required
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask import send_file
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from libs.login import login_required
from services.tools_manage_service import ToolManageService
from werkzeug.exceptions import Forbidden
import io
class ToolProviderListApi(Resource):
@setup_required
@@ -171,8 +169,8 @@ class ToolApiProviderUpdateApi(Resource):
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=str, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
args = parser.parse_args()

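Two parser changes above: `icon` is now parsed as a dict rather than a string, and `privacy_policy` becomes nullable. An illustrative body (only the fields visible in this hunk are included, and the icon's inner keys are an assumption):

```python
# Illustrative JSON body for ToolApiProviderUpdateApi; unshown required
# fields are omitted, and the icon's key names are assumed, not confirmed.
payload = {
    'schema': '{"openapi": "3.0.0", ...}',               # OpenAPI schema string
    'provider': 'my_api_tool',
    'original_provider': 'my_api_tool',
    'icon': {'content': '🛠', 'background': '#EFF1F5'},  # dict now, not str
    'privacy_policy': None,                              # nullable=True now
}
```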
View File

@@ -1,10 +1,12 @@
# -*- coding:utf-8 -*-
import json
from functools import wraps
from controllers.console.workspace.error import AccountNotInitializedError
from flask import abort, current_app
from flask import abort, current_app, request
from flask_login import current_user
from services.feature_service import FeatureService
from services.operation_service import OperationService
def account_initialization_required(view):
@@ -73,3 +75,20 @@ def cloud_edition_billing_resource_check(resource: str,
return decorated
return interceptor
def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):
try:
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get('utm_info')
if utm_info:
utm_info = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info)
except Exception as e:
pass
return view(*args, **kwargs)
return decorated

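The new `cloud_utm_record` decorator replaces the deleted `/operation/utm` endpoint shown earlier: it reads a `utm_info` cookie, decodes it as JSON, and forwards it to `OperationService.record_utm`. The cookie shape it expects, with field names taken from the removed endpoint's parser and illustrative values:

```python
import json

# The decorator looks up request.cookies['utm_info'] and json.loads() it.
utm_info = json.dumps({
    'utm_source': 'twitter',
    'utm_medium': 'social',
    'utm_campaign': 'launch',
    'utm_content': '',
    'utm_term': '',
})
# Sent by the browser as:  Cookie: utm_info={"utm_source": "twitter", ...}
```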
View File

@@ -6,5 +6,4 @@ bp = Blueprint('files', __name__)
api = ExternalApi(bp)
from . import image_preview
from . import tool_files
from . import image_preview, tool_files

View File

@@ -1,10 +1,10 @@
from controllers.files import api
from core.tools.tool_file_manager import ToolFileManager
from flask import Response
from flask_restful import Resource, reqparse
from libs.exception import BaseHTTPException
from werkzeug.exceptions import NotFound, Forbidden
from werkzeug.exceptions import Forbidden, NotFound
from core.tools.tool_file_manager import ToolFileManager
class ToolFilePreviewApi(Resource):
def get(self, file_id, extension):

View File

@@ -1,15 +1,14 @@
# -*- coding:utf-8 -*-
import json
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from extensions.ext_database import db
from flask import current_app
from flask_restful import fields, marshal_with
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider
import json
from extensions.ext_database import db
class AppParameterApi(AppApiResource):
"""Resource for app variables."""

View File

@@ -13,7 +13,7 @@ from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask import Response, stream_with_context, request
from flask import Response, stream_with_context
from flask_restful import reqparse
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -75,18 +75,22 @@ class CompletionApi(AppApiResource):
class CompletionStopApi(AppApiResource):
def post(self, app_model, _, task_id):
def post(self, app_model, end_user, task_id):
if app_model.mode != 'completion':
raise AppUnavailableError()
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
end_user_id = args.get('user')
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
@@ -146,13 +150,22 @@ class ChatApi(AppApiResource):
class ChatStopApi(AppApiResource):
def post(self, app_model, _, task_id):
def post(self, app_model, end_user, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()
end_user_id = request.get_json().get('user')
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200

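After this change, both stop endpoints resolve a real end user: when the API token is not already bound to one, the required `user` field in the JSON body is used to create or fetch it via `create_or_update_end_user_for_user_id`. A hedged usage sketch (the `/v1` prefix and route shape are assumptions, not shown in this diff):

```python
# Hedged sketch of stopping a running chat generation over the service API.
import requests

task_id = 'the-running-task-id'  # placeholder

resp = requests.post(
    f'http://localhost:5001/v1/chat-messages/{task_id}/stop',
    headers={'Authorization': 'Bearer app-xxxxxxxxxxxx'},  # app API token
    json={'user': 'end-user-123'},  # required when no end user is bound to the token
)
print(resp.json())  # {'result': 'success'}
```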
View File

@@ -1,4 +1,3 @@
from models.dataset import Dataset
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
@@ -9,6 +8,7 @@ from fields.dataset_fields import dataset_detail_fields
from flask import request
from flask_restful import marshal, reqparse
from libs.login import current_user
from models.dataset import Dataset
from services.dataset_service import DatasetService

View File

@@ -1,8 +1,7 @@
from controllers.service_api import api
from flask import current_app
from flask_restful import Resource
from controllers.service_api import api
class IndexApi(Resource):
def get(self):

View File

@@ -1,15 +1,14 @@
# -*- coding:utf-8 -*-
import json
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from flask import current_app
from flask_restful import fields, marshal_with
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider
from extensions.ext_database import db
import json
class AppParameterApi(WebApiResource):
"""Resource for app variables."""

View File

@@ -2,8 +2,13 @@ import time
from typing import Generator, List, Optional, Tuple, Union, cast
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, \
PromptTemplateEntity, ExternalDataVariableEntity, ApplicationGenerateEntity, InvokeFrom
from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity,
ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity,
PromptTemplateEntity)
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -11,12 +16,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.annotation_reply import AnnotationReplyFeature
from core.prompt.prompt_transform import PromptTransform
from models.model import App, MessageAnnotation, Message
from models.model import App, Message, MessageAnnotation
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,

View File

@ -3,19 +3,19 @@ import logging
from typing import cast
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
from core.features.assistant_cot_runner import AssistantCotApplicationRunner
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
AgentEntity
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@ -169,7 +169,7 @@ class AssistantApplicationRunner(AppRunner):
# load tool variables
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
user_id=application_generate_entity.user_id,
tanent_id=application_generate_entity.tenant_id)
tenant_id=application_generate_entity.tenant_id)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
@ -194,6 +194,13 @@ class AssistantApplicationRunner(AppRunner):
memory=memory,
)
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
@ -209,9 +216,9 @@ class AssistantApplicationRunner(AppRunner):
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_cot_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
@ -229,10 +236,10 @@ class AssistantApplicationRunner(AppRunner):
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_fc_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
@ -246,13 +253,13 @@ class AssistantApplicationRunner(AppRunner):
agent=True
)
def _load_tool_variables(self, conversation_id: str, user_id: str, tanent_id: str) -> ToolConversationVariables:
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tanent_id
ToolConversationVariables.tenant_id == tenant_id
).first()
if tool_variables:
@ -263,7 +270,7 @@ class AssistantApplicationRunner(AppRunner):
tool_variables = ToolConversationVariables(
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tanent_id,
tenant_id=tenant_id,
variables_str='[]',
)
db.session.add(tool_variables)

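The strategy override above kicks in when the model schema advertises tool calling: the runner upgrades a chain-of-thought agent to native function calling. A self-contained sketch of that feature check, with small local stand-ins for the Dify entities:

from enum import Enum

class ModelFeature(Enum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"

class Strategy(Enum):
    CHAIN_OF_THOUGHT = "chain-of-thought"
    FUNCTION_CALLING = "function-calling"

def pick_strategy(configured: Strategy, model_features: list) -> Strategy:
    # Prefer native function calling whenever the schema advertises
    # single or multi tool-call support, as the runner above does.
    if {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} & set(model_features or []):
        return Strategy.FUNCTION_CALLING
    return configured

assert pick_strategy(Strategy.CHAIN_OF_THOUGHT,
                     [ModelFeature.TOOL_CALL]) is Strategy.FUNCTION_CALLING
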
View File

@ -4,8 +4,7 @@ from typing import Optional
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import (ApplicationGenerateEntity, DatasetEntity,
InvokeFrom, ModelConfigEntity)
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance

View File

@ -6,21 +6,21 @@ from typing import Generator, Optional, Union, cast
from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent,
QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent,
QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent,
QueueMessageFileEvent, QueueAgentMessageEvent)
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent,
QueueErrorEvent, QueueMessageEndEvent, QueueMessageEvent,
QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent,
QueueRetrieverResourcesEvent, QueueStopEvent)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
PromptMessage, PromptMessageContentType, PromptMessageRole,
TextPromptMessageContent)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile

View File

@ -9,11 +9,12 @@ from core.app_runner.basic_app_runner import BasicApplicationRunner
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
from core.entities.application_entities import (AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentToolEntity,
ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity,
AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentPromptEntity,
AgentToolEntity, ApplicationGenerateEntity,
AppOrchestrationConfigEntity, DatasetEntity,
DatasetRetrieveConfigEntity, ExternalDataVariableEntity,
FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity,
SensitiveWordAvoidanceEntity, AgentPromptEntity)
SensitiveWordAvoidanceEntity)
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileObj

View File

@ -4,10 +4,10 @@ from enum import Enum
from typing import Any, Generator
from core.entities.application_entities import InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentThoughtEvent, QueueErrorEvent,
QueueMessage, QueueMessageEndEvent, QueueMessageEvent,
QueueMessageReplaceEvent, QueuePingEvent, QueueRetrieverResourcesEvent,
QueueStopEvent, QueueMessageFileEvent, QueueAgentMessageEvent)
from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent,
QueueAgentThoughtEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent,
QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent,
QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought, MessageFile

View File

@ -1,9 +1,10 @@
import os
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from pydantic import BaseModel
class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
"""Callback Handler that prints to std out."""

View File

@ -1,5 +1,6 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

View File

@ -8,9 +8,8 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from langchain.embeddings.base import Embeddings
from extensions.ext_redis import redis_client
from langchain.embeddings.base import Embeddings
from libs import helper
from models.dataset import Embedding
from sqlalchemy.exc import IntegrityError

View File

@ -1,12 +1,11 @@
from enum import Enum
from typing import Optional, Any, cast, Literal, Union
from pydantic import BaseModel
from typing import Any, Literal, Optional, Union, cast
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import AIModelEntity
from pydantic import BaseModel
class ModelConfigEntity(BaseModel):

View File

@ -9,6 +9,7 @@ from pydantic import BaseModel
class QuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
CREDITS = 'credits'
class SystemConfigurationStatus(Enum):

View File

@ -1,27 +1,26 @@
import logging
from typing import cast, Optional, List
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, WikipediaQueryRun, Tool
from pydantic import BaseModel, Field
from typing import List, Optional, cast
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
from core.entities.application_entities import (AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity, InvokeFrom,
ModelConfigEntity)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from models.dataset import Dataset
from models.model import Message
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)

View File

@ -1,34 +1,32 @@
import logging
import json
from typing import Optional, List, Tuple, Union
import logging
from datetime import datetime
from mimetypes import guess_extension
from typing import List, Optional, Tuple, Union, cast
from core.app_runner.app_runner import AppRunner
from extensions.ext_database import db
from models.model import MessageAgentThought, Message, MessageFile
from models.tools import ToolConversationVariables
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, \
ToolRuntimeVariablePool, ToolParamter
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ModelConfigEntity, AgentEntity, AgentToolEntity
from core.application_queue_manager import ApplicationQueueManager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.entities.application_entities import ModelConfigEntity, \
AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.entities.application_entities import (AgentEntity, AgentToolEntity, ApplicationGenerateEntity,
AppOrchestrationConfigEntity, InvokeFrom, ModelConfigEntity)
from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import (ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter,
ToolRuntimeVariablePool)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.model import Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@ -45,6 +43,7 @@ class BaseAssistantApplicationRunner(AppRunner):
prompt_messages: Optional[List[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
"""
Agent runner
@ -71,6 +70,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.history_prompt_messages = prompt_messages
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback
self.agent_callback = DifyAgentCallbackHandler()
@ -95,9 +95,17 @@ class BaseAssistantApplicationRunner(AppRunner):
MessageAgentThought.message_id == self.message.id,
).count()
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
"""
Repacket app orchestration config
Repack app orchestration config
"""
if app_orchestration_config.prompt_template.simple_prompt_template is None:
app_orchestration_config.prompt_template.simple_prompt_template = ''
@ -113,7 +121,7 @@ class BaseAssistantApplicationRunner(AppRunner):
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please dirct user to check it."
result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now."
@ -128,7 +136,7 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tanent_id=self.application_generate_entity.tenant_id,
tenant_id=self.application_generate_entity.tenant_id,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
@ -172,20 +180,20 @@ class BaseAssistantApplicationRunner(AppRunner):
for parameter in parameters:
parameter_type = 'string'
enum = []
if parameter.type == ToolParamter.ToolParameterType.STRING:
if parameter.type == ToolParameter.ToolParameterType.STRING:
parameter_type = 'string'
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean'
elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
parameter_type = 'number'
elif parameter.type == ToolParamter.ToolParameterType.SELECT:
elif parameter.type == ToolParameter.ToolParameterType.SELECT:
for option in parameter.options:
enum.append(option.value)
parameter_type = 'string'
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParamter.ToolParameterForm.FORM:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
@ -194,7 +202,7 @@ class BaseAssistantApplicationRunner(AppRunner):
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParamter.ToolParameterType.SELECT:
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
@ -202,7 +210,7 @@ class BaseAssistantApplicationRunner(AppRunner):
# convert tool parameter config to correct type
try:
if parameter.type == ToolParamter.ToolParameterType.NUMBER:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
@ -213,11 +221,11 @@ class BaseAssistantApplicationRunner(AppRunner):
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParamter.ToolParameterType.SELECT, ToolParamter.ToolParameterType.STRING]:
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParamter.ToolParameterType:
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
@ -225,7 +233,7 @@ class BaseAssistantApplicationRunner(AppRunner):
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParamter.ToolParameterForm.LLM:
elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
@ -279,20 +287,20 @@ class BaseAssistantApplicationRunner(AppRunner):
for parameter in tool_runtime_parameters:
parameter_type = 'string'
enum = []
if parameter.type == ToolParamter.ToolParameterType.STRING:
if parameter.type == ToolParameter.ToolParameterType.STRING:
parameter_type = 'string'
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean'
elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
parameter_type = 'number'
elif parameter.type == ToolParamter.ToolParameterType.SELECT:
elif parameter.type == ToolParameter.ToolParameterType.SELECT:
for option in parameter.options:
enum.append(option.value)
parameter_type = 'string'
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParamter.ToolParameterForm.LLM:
if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',

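Most of this file is the ToolParamter to ToolParameter rename, but the surrounding logic, coercing a form-supplied value into the declared parameter type, is worth illustrating. A standalone sketch, assuming plain string type names where the real code uses the ToolParameter enums:

def coerce_parameter(value, param_type: str, options=None):
    # Coerce a form-supplied value into the declared parameter type,
    # mirroring the NUMBER/BOOLEAN/SELECT/STRING branches above.
    if param_type == "number":
        if isinstance(value, (int, float)):
            return value
        return float(value) if "." in str(value) else int(value)
    if param_type == "boolean":
        return bool(value)
    if param_type == "select":
        if value not in (options or []):
            raise ValueError(f"value {value} not in options {options}")
        return value
    return str(value)

assert coerce_parameter("1.5", "number") == 1.5
assert coerce_parameter(0, "boolean") is False
assert coerce_parameter("b", "select", options=["a", "b"]) == "b"
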
View File

@ -1,27 +1,23 @@
import json
import logging
import re
from typing import Literal, Union, Generator, Dict, List
from typing import Dict, Generator, List, Literal, Union
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.application_queue_manager import PublishFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, \
UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
from core.model_manager import ModelInstance
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
ToolProviderCredentialValidationError
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError,
ToolProviderCredentialValidationError, ToolProviderNotFoundError)
from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Union[Generator, LLMResult]:
@ -29,7 +25,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
Run Cot agent application
"""
app_orchestration_config = self.app_orchestration_config
self._repacket_app_orchestration_config(app_orchestration_config)
self._repack_app_orchestration_config(app_orchestration_config)
agent_scratchpad: List[AgentScratchpadUnit] = []
@ -72,7 +68,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
}
final_answer = ''
def increse_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@ -82,6 +78,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there are no more tool calls
function_call_state = False
@ -104,7 +102,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages
prompt_messages = self._originze_cot_prompt_messages(
prompt_messages = self._organize_cot_prompt_messages(
mode=app_orchestration_config.model_config.mode,
prompt_messages=prompt_messages,
tools=prompt_messages_tools,
@ -137,7 +135,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# get llm usage
if llm_result.usage:
increse_usage(llm_usage, llm_result.usage)
increase_usage(llm_usage, llm_result.usage)
# publish agent thought if it's first iteration
if iteration_step == 1:
@ -207,7 +205,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
try:
tool_response = tool_instance.invoke(
user_id=self.user_id,
tool_paramters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
)
# transform tool response to llm friendly response
tool_response = self.transform_tool_invoke_messages(tool_response)
@ -225,15 +223,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
error_response = f"Plese check your tool provider credentials"
error_response = f"Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool_call_name}"
except (
ToolParamterValidationError
ToolParameterValidationError
) as e:
error_response = f"tool paramters validation error: {e}, please check your tool paramters"
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
except Exception as e:
@ -390,7 +388,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input:
if action_name and action_input is not None:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
@ -468,7 +466,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if not next_iteration.find("{{observation}}") >= 0:
raise ValueError("{{observation}} is required in next_iteration")
def _convert_strachpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
"""
convert agent scratchpad list to str
"""
@ -480,7 +478,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
return result
def _originze_cot_prompt_messages(self, mode: Literal["completion", "chat"],
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
agent_scratchpad: List[AgentScratchpadUnit],
@ -489,7 +487,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
input: str,
) -> List[PromptMessage]:
"""
originze chain of thought prompt messages, a standard prompt message is like:
organize chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.
{{instruction}}
@ -527,7 +525,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
.replace("{{tools}}", tools_str) \
.replace("{{tool_names}}", tool_names)
# originze prompt messages
# organize prompt messages
if mode == "chat":
# override system message
overrided = False
@ -558,7 +556,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
return prompt_messages
elif mode == "completion":
# parse agent scratchpad
agent_scratchpad_str = self._convert_strachpad_list_to_str(agent_scratchpad)
agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
# parse prompt messages
return [UserPromptMessage(
content=first_prompt.replace("{{instruction}}", instruction)

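The renamed increase_usage helper above accumulates per-iteration LLM usage across the agent loop. A tiny runnable sketch of the same seed-then-add behavior, with a dataclass standing in for LLMUsage:

from dataclasses import dataclass

@dataclass
class Usage:  # stand-in for LLMUsage
    prompt_tokens: int = 0
    completion_tokens: int = 0

def increase_usage(final_llm_usage_dict: dict, usage: Usage) -> None:
    # The first iteration seeds the total; later iterations add onto it.
    if final_llm_usage_dict["usage"] is None:
        final_llm_usage_dict["usage"] = usage
    else:
        final_llm_usage_dict["usage"].prompt_tokens += usage.prompt_tokens
        final_llm_usage_dict["usage"].completion_tokens += usage.completion_tokens

llm_usage = {"usage": None}
increase_usage(llm_usage, Usage(10, 5))
increase_usage(llm_usage, Usage(3, 2))
assert (llm_usage["usage"].prompt_tokens,
        llm_usage["usage"].completion_tokens) == (13, 7)
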
View File

@ -1,27 +1,21 @@
import json
import logging
from typing import Any, Dict, Generator, List, Tuple, Union
from typing import Union, Generator, Dict, Any, Tuple, List
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
from core.model_manager import ModelInstance
from core.application_queue_manager import PublishFrom
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
ToolProviderCredentialValidationError
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, ToolPromptMessage, UserPromptMessage)
from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError,
ToolProviderCredentialValidationError, ToolProviderNotFoundError)
from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Generator[LLMResultChunk, None, None]:
@ -81,6 +75,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
@ -96,17 +92,16 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
tool_input='',
messages_ids=message_file_ids
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop,
stream=True,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
@ -122,11 +117,45 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
current_llm_usage = None
for chunk in chunks:
if self.stream_tool_call:
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += chunk.delta.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result: LLMResult = chunks
# check if there is any tool call
if self.check_tool_calls(chunk):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
@ -138,18 +167,46 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += chunk.delta.message.content
response += result.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
if not result.message.content:
result.message.content = ''
yield chunk
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
)
)
if tool_calls:
prompt_messages.append(AssistantPromptMessage(
content='',
name='',
tool_calls=[AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls]
))
# save thought
self.save_agent_thought(
@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
final_answer += response + '\n'
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@ -184,7 +247,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
try:
tool_invoke_message = tool_instance.invoke(
user_id=self.user_id,
tool_paramters=tool_call_args,
tool_parameters=tool_call_args,
)
# transform tool invoke message to get LLM friendly message
tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
@ -203,15 +266,15 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
message_file_ids.append(message_file.id)
except ToolProviderCredentialValidationError as e:
error_response = f"Plese check your tool provider credentials"
error_response = f"Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool_call_name}"
except (
ToolParamterValidationError
ToolParameterValidationError
) as e:
error_response = f"tool paramters validation error: {e}, please check your tool paramters"
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
except Exception as e:
@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
))
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
json.loads(prompt_message.function.arguments),
))
return tool_calls
def organize_prompt_messages(self, prompt_template: str,
query: str = None,

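check_blocking_tool_calls and extract_blocking_tool_calls mirror their streaming counterparts but operate on a whole LLMResult rather than on chunk deltas. A standalone sketch of the extraction step, with plain dicts standing in for the message objects:

import json

def check_blocking_tool_calls(message: dict) -> bool:
    # A blocking result carries tool calls on the message itself.
    return bool(message.get("tool_calls"))

def extract_blocking_tool_calls(message: dict):
    # Returns (tool_call_id, tool_call_name, parsed args) tuples,
    # as the method above does for the runtime tool-call objects.
    return [(tc["id"], tc["function"]["name"],
             json.loads(tc["function"]["arguments"]))
            for tc in message.get("tool_calls", [])]

message = {"tool_calls": [{"id": "call-1",
                           "function": {"name": "get_weather",
                                        "arguments": '{"city": "Berlin"}'}}]}
assert check_blocking_tool_calls(message)
assert extract_blocking_tool_calls(message) == \
    [("call-1", "get_weather", {"city": "Berlin"})]
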
View File

@ -1,11 +1,11 @@
from typing import Dict, List, Optional, Union
import requests
from core.file.file_obj import FileObj, FileTransferMethod, FileType, FileBelongsTo
from services.file_service import IMAGE_EXTENSIONS
from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
from extensions.ext_database import db
from models.account import Account
from models.model import AppModelConfig, EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:

View File

@ -2,7 +2,7 @@ from typing import Optional
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from flask import Flask, Config
from flask import Config, Flask
from models.provider import ProviderQuotaType
from pydantic import BaseModel
@ -20,10 +20,6 @@ class TrialHostingQuota(HostingQuota):
class PaidHostingQuota(HostingQuota):
quota_type: ProviderQuotaType = ProviderQuotaType.PAID
stripe_price_id: str = None
increase_quota: int = 1
min_quantity: int = 20
max_quantity: int = 100
class FreeHostingQuota(HostingQuota):
@ -102,7 +98,7 @@ class HostingConfiguration:
)
def init_openai(self, app_config: Config) -> HostingProvider:
quota_unit = QuotaUnit.TIMES
quota_unit = QuotaUnit.CREDITS
quotas = []
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
@ -114,6 +110,8 @@ class HostingConfiguration:
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="whisper-1", model_type=ModelType.SPEECH2TEXT),
]
@ -122,10 +120,20 @@ class HostingConfiguration:
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_quota = PaidHostingQuota(
stripe_price_id=app_config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
increase_quota=int(app_config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
min_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
max_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
restrict_models=[
RestrictModel(model="gpt-4", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-32k", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
)
quotas.append(paid_quota)
@ -164,12 +172,7 @@ class HostingConfiguration:
quotas.append(trial_quota)
if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
paid_quota = PaidHostingQuota(
stripe_price_id=app_config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
increase_quota=int(app_config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
min_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
max_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
)
paid_quota = PaidHostingQuota()
quotas.append(paid_quota)
if len(quotas) > 0:

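The hosted OpenAI quota above moves from QuotaUnit.TIMES to the new QuotaUnit.CREDITS, which implies per-model charging rather than one unit per call. A hypothetical sketch of what a credits-style charge could look like; the cost table below is illustrative only and not taken from the diff:

from enum import Enum

class QuotaUnit(Enum):
    TIMES = "times"
    TOKENS = "tokens"
    CREDITS = "credits"

# Assumed per-model credit costs, purely for illustration.
MODEL_CREDIT_COST = {"gpt-3.5-turbo": 1, "gpt-4": 20}

def charge(quota_left: int, unit: QuotaUnit, model: str) -> int:
    if unit is QuotaUnit.TIMES:
        return quota_left - 1  # every call costs one unit
    if unit is QuotaUnit.CREDITS:
        return quota_left - MODEL_CREDIT_COST.get(model, 1)
    return quota_left  # token-based quotas are charged elsewhere

assert charge(100, QuotaUnit.CREDITS, "gpt-4") == 80
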
View File

@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelManager, ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

View File

@ -12,8 +12,8 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager

View File

@ -78,6 +78,7 @@ class ModelFeature(Enum):
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
class DefaultParameterName(Enum):

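The new STREAM_TOOL_CALL feature is what the assistant runner consults to decide whether to invoke the model with stream=True during tool calls. A minimal sketch of that gate, with a local copy of the relevant enum members:

from enum import Enum

class ModelFeature(Enum):  # local copy of the relevant members
    TOOL_CALL = "tool-call"
    STREAM_TOOL_CALL = "stream-tool-call"

def should_stream_tool_calls(features) -> bool:
    # stream=True only when the model declares stream-tool-call,
    # matching the stream_tool_call flag set in the assistant runner.
    return ModelFeature.STREAM_TOOL_CALL in (features or [])

assert should_stream_tool_calls([ModelFeature.TOOL_CALL]) is False
assert should_stream_tool_calls([ModelFeature.TOOL_CALL,
                                 ModelFeature.STREAM_TOOL_CALL]) is True
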
View File

@ -1,13 +1,12 @@
import uuid
import hashlib
import subprocess
import uuid
from abc import abstractmethod
from typing import Optional
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.entities.model_entities import ModelPropertyKey
class TTSModel(AIModel):

View File

@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={

View File

@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
index = 0
full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model
system_fingerprint = None
completion = ''
@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
delta.delta.function_call is None:
continue
# assistant_message_tool_calls = delta.delta.tool_calls
assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if delta_assistant_message_function_call_storage is not None:
# handle process of stream function call
if assistant_message_function_call:
# message has not ended yet
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
continue
else:
# message has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
continue
# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
else:
raise ValueError(f"Got unknown type {message}")
if message.name is not None:
if message.name:
message_dict["name"] = message.name
return message_dict
@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(tool.get("type")))
num_tokens += len(encoding.encode('function'))
# calculate num tokens for function object

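The Azure change above buffers streamed function_call fragments in delta_assistant_message_function_call_storage and only emits the call once fragments stop arriving. A standalone sketch of that buffer-then-flush loop, with simplified chunk dicts where None marks a chunk without a function_call delta:

def accumulate_function_calls(deltas):
    # storage plays the role of delta_assistant_message_function_call_storage.
    storage = None
    for fc in deltas:
        if storage is not None:
            if fc is not None:
                # message has not ended yet: keep appending argument text
                storage["arguments"] += fc["arguments"]
                continue
            # message has ended: flush the buffered call
            completed, storage = storage, None
            yield completed
        elif fc is not None:
            # start of a streamed function call
            storage = {"name": fc["name"], "arguments": fc["arguments"] or ""}
    if storage is not None:
        yield storage

deltas = [{"name": "lookup", "arguments": '{"q":'},
          {"name": None, "arguments": ' "dify"}'},
          None]
assert list(accumulate_function_calls(deltas)) == \
    [{"name": "lookup", "arguments": '{"q": "dify"}'}]
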
View File

@ -8,9 +8,9 @@ model_properties:
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
- name: top_p
use_template: top_p
- name: topK
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top K

View File

@ -8,9 +8,9 @@ model_properties:
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
- name: top_p
use_template: top_p
- name: topK
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top K

View File

@ -8,9 +8,9 @@ model_properties:
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
- name: top_p
use_template: top_p
- name: topK
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top K

View File

@ -1,15 +1,14 @@
import json
import logging
from typing import Generator, List, Optional, Union
import boto3
from botocore.exceptions import ClientError, EndpointConnectionError, NoRegionError, ServiceNotInRegionError, UnknownServiceError
from botocore.config import Config
import json
from botocore.exceptions import (ClientError, EndpointConnectionError, NoRegionError, ServiceNotInRegionError,
UnknownServiceError)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
PromptMessageTool, SystemPromptMessage, UserPromptMessage)
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -250,9 +249,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
invoke = runtime_client.invoke_model
try:
body_jsonstr=json.dumps(payload)
response = invoke(
body=json.dumps(payload),
modelId=model,
contentType="application/json",
accept= "*/*",
body=body_jsonstr
)
except ClientError as ex:
error_code = ex.response['Error']['Code']
@ -385,7 +387,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if not chunk:
exception_name = next(iter(event))
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
payload = json.loads(chunk.get('bytes').decode())
@ -396,7 +397,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
finish_reason = payload.get("completion_reason")
elif model_prefix == "anthropic":
content_delta = payload
content_delta = payload.get("completion")
finish_reason = payload.get("stop_reason")
elif model_prefix == "cohere":
@ -410,12 +411,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
index += 1
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content = content_delta if content_delta else '',
)
index += 1
if not finish_reason:
yield LLMResultChunk(
model=model,

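Besides the import reshuffle, the Bedrock fix above reads the Anthropic stream delta from the payload's completion field instead of passing the whole payload dict through. A minimal sketch of that parsing step:

import json

def parse_anthropic_chunk(raw: bytes):
    payload = json.loads(raw.decode())
    # The text delta lives under "completion"; the old code used the
    # whole payload dict as the delta.
    return payload.get("completion"), payload.get("stop_reason")

delta, stop = parse_anthropic_chunk(b'{"completion": "Hel", "stop_reason": null}')
assert delta == "Hel" and stop is None
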
View File

@ -5,7 +5,8 @@ from typing import Generator, List, Optional, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
PromptMessageTool, SystemPromptMessage, UserPromptMessage)
PromptMessageTool, SystemPromptMessage, ToolPromptMessage,
UserPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -194,6 +195,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")

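The ChatGLM hunk above adds a mapping for ToolPromptMessage: tool results go back to the model under the function role. A self-contained sketch of that dispatch, with trivial str subclasses standing in for the runtime message classes:

class SystemPromptMessage(str):  # stand-ins for the runtime entities
    pass

class ToolPromptMessage(str):
    pass

def to_chatglm_dict(message) -> dict:
    if isinstance(message, SystemPromptMessage):
        return {"role": "system", "content": str(message)}
    if isinstance(message, ToolPromptMessage):
        # tool output is returned to ChatGLM under the "function" role
        return {"role": "function", "content": str(message)}
    raise ValueError(f"Unknown message type {type(message)}")

assert to_chatglm_dict(ToolPromptMessage("42"))["role"] == "function"
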
View File

@ -1,19 +1,18 @@
import logging
from typing import Generator, List, Optional, Union, cast, Tuple
from typing import Generator, List, Optional, Tuple, Union, cast
import cohere
from cohere.responses import Chat, Generations
from cohere.responses.chat import StreamingChat, StreamTextGeneration, StreamEnd
from cohere.responses.generation import StreamingText, StreamingGenerations
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
from cohere.responses.generation import StreamingGenerations, StreamingText
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
PromptMessageContentType, SystemPromptMessage,
TextPromptMessageContent, UserPromptMessage,
PromptMessageTool)
PromptMessageContentType, PromptMessageTool,
SystemPromptMessage, TextPromptMessageContent,
UserPromptMessage)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeError, \
InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

View File

@ -4,11 +4,10 @@ from typing import Optional, Tuple
import cohere
import numpy as np
from cohere.responses import Tokens
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

View File

@ -4,6 +4,8 @@ label:
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 16384

View File

@ -4,6 +4,8 @@ label:
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
"""
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: List[MinimaxMessage], model_parameters: dict,
tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
continue
for choice in choices:
print(choice)
message = choice['delta']
yield MinimaxMessage(
content=message,

View File

@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
"""
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: List[MinimaxMessage], model_parameters: dict,
tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
**extra_kwargs
}
if tools:
body['functions'] = tools
body['function_call'] = { 'type': 'auto' }
try:
response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
"""
handle stream chat generate response
"""
function_call_storage = None
for line in response.iter_lines():
if not line:
continue
@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
msg = data['base_resp']['status_msg']
self._handle_error(code, msg)
if data['reply']:
if data['reply'] or 'usage' in data and data['usage']:
total_tokens = data['usage']['total_tokens']
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
'total_tokens': total_tokens
}
message.stop_reason = data['choices'][0]['finish_reason']
if function_call_storage:
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = function_call_storage
yield function_call_message
yield message
return
@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
continue
for choice in choices:
message = choice['messages'][0]['text']
if not message:
continue
message = choice['messages'][0]
if 'function_call' in message:
if not function_call_storage:
function_call_storage = message['function_call']
if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
function_call_storage['arguments'] = ''
continue
else:
function_call_storage['arguments'] += message['function_call']['arguments']
continue
else:
if function_call_storage:
message['function_call'] = function_call_storage
function_call_storage = None
yield MinimaxMessage(
content=message,
role=MinimaxMessage.Role.ASSISTANT.value
)
minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
if 'function_call' in message:
minimax_message.function_call = message['function_call']
if 'text' in message:
minimax_message.content = message['text']
yield minimax_message
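
The streaming branch above buffers partial function_call payloads in function_call_storage, concatenating the arguments fragments chunk by chunk and only emitting the completed call once a plain text chunk (or the final usage chunk) arrives. A minimal self-contained sketch of that accumulation pattern; the chunk dicts here are hypothetical stand-ins for the Minimax SSE payload, not the provider's wire format:

from typing import Any, Dict, Generator, Iterable, Optional

def accumulate_function_call(chunks: Iterable[Dict[str, Any]]) -> Generator[Dict[str, Any], None, None]:
    """Merge streamed function-call fragments into one complete call."""
    storage: Optional[Dict[str, Any]] = None
    for chunk in chunks:
        call = chunk.get('function_call')
        if call is not None:
            if storage is None:
                # the first fragment carries the name; arguments may be missing
                storage = {'name': call.get('name', ''), 'arguments': call.get('arguments') or ''}
            else:
                # follow-up fragments only extend the JSON arguments string
                storage['arguments'] += call.get('arguments', '')
            continue
        if storage is not None:
            # a non-function chunk closes the buffered call
            yield {'function_call': storage}
            storage = None
        yield chunk
    if storage is not None:
        yield {'function_call': storage}

stream = [
    {'function_call': {'name': 'get_weather', 'arguments': '{"city": '}},
    {'function_call': {'arguments': '"Berlin"}'}},
    {'text': 'done'},
]
for out in accumulate_function_call(stream):
    print(out)
# {'function_call': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}}
# {'text': 'done'}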

View File

@ -2,7 +2,7 @@ from typing import Generator, List
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
SystemPromptMessage, ToolPromptMessage, UserPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
"""
client: MinimaxChatCompletionPro = self.model_apis[model]()
if tools:
tools = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
response = client.generate(
model=model,
api_key=credentials['minimax_api_key'],
@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
elif isinstance(prompt_message, UserPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.function_call={
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
return message
return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
elif isinstance(prompt_message, ToolPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
else:
raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
finish_reason=message.stop_reason if message.stop_reason else None,
),
)
elif message.function_call:
if 'name' not in message.function_call or 'arguments' not in message.function_call:
continue
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content='',
tool_calls=[AssistantPromptMessage.ToolCall(
id='',
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.function_call['name'],
arguments=message.function_call['arguments']
)
)]
),
),
)
else:
yield LLMResultChunk(
model=model,
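
In the hunks above, each PromptMessageTool is reshaped into the plain name/description/parameters dict that Minimax's functions field expects, and a buffered function_call is later surfaced as a single AssistantPromptMessage.ToolCall. A hedged sketch of the tool projection, with a dataclass standing in for the runtime's pydantic PromptMessageTool model:

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class PromptMessageTool:
    # stand-in for core.model_runtime.entities.message_entities.PromptMessageTool
    name: str
    description: str
    parameters: Dict[str, Any] = field(default_factory=dict)

def to_minimax_functions(tools: List[PromptMessageTool]) -> List[Dict[str, Any]]:
    """Project runtime tool definitions onto the `functions` list sent to Minimax."""
    return [
        {
            'name': tool.name,
            'description': tool.description,
            'parameters': tool.parameters,
        }
        for tool in tools
    ]

print(to_minimax_functions([PromptMessageTool(
    name='current_time',
    description='Return the current time for a timezone',
    parameters={'type': 'object', 'properties': {'tz': {'type': 'string'}}},
)]))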

View File

@ -7,13 +7,23 @@ class MinimaxMessage:
USER = 'USER'
ASSISTANT = 'BOT'
SYSTEM = 'SYSTEM'
FUNCTION = 'FUNCTION'
role: str = Role.USER.value
content: str
usage: Dict[str, int] = None
stop_reason: str = ''
function_call: Dict[str, Any] = None
def to_dict(self) -> Dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
'sender_type': 'BOT',
'sender_name': '专家',
'text': '',
'function_call': self.function_call
}
return {
'sender_type': self.role,
'sender_name': '' if self.role == 'USER' else '专家',
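
to_dict above gives function-call turns a dedicated shape, empty text plus the buffered function_call payload, with sender_name fixed to '专家' ('expert'). A compact sketch of the same dispatch, assuming only BOT-role messages ever carry a function call:

from typing import Any, Dict, Optional

def message_to_dict(role: str, content: str,
                    function_call: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # function-call turns are serialized with empty text, mirroring the method above
    if function_call and role == 'BOT':
        return {'sender_type': 'BOT', 'sender_name': '专家',
                'text': '', 'function_call': function_call}
    return {'sender_type': role,
            'sender_name': '' if role == 'USER' else '专家',
            'text': content}

print(message_to_dict('BOT', '', {'name': 'get_weather', 'arguments': '{}'}))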

View File

@ -2,21 +2,20 @@ import json
import logging
import re
from decimal import Decimal
from typing import Optional, Generator, Union, List, cast
from typing import Generator, List, Optional, Union, cast
from urllib.parse import urljoin
import requests
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
TextPromptMessageContent, SystemPromptMessage
from core.model_runtime.entities.model_entities import I18nObject, ModelType, \
PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \
ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
PromptMessage, PromptMessageContentType, PromptMessageTool,
SystemPromptMessage, TextPromptMessageContent,
UserPromptMessage)
from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, I18nObject,
ModelFeature, ModelPropertyKey, ModelType, ParameterRule,
ParameterType, PriceConfig)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

View File

@ -11,7 +11,7 @@ help:
en_US: How to integrate with Ollama
zh_Hans: 如何集成 Ollama
url:
en_US: https://docs.dify.ai/advanced/model-configuration/ollama
en_US: https://docs.dify.ai/tutorials/model-configuration/ollama
supported_model_types:
- llm
- text-embedding

View File

@ -1,19 +1,18 @@
import json
import logging
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json
import numpy as np
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
PriceConfig
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
PriceConfig, PriceType)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

View File

@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 8192

View File

@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
else:
raise ValueError(f"Got unknown type {message}")
if message.name is not None:
if message.name:
message_dict["name"] = message.name
return message_dict
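
Moving from `message.name is not None` to a plain truthiness test above means an empty-string name is now omitted from the payload as well as a missing one. A small sketch of the resulting behaviour:

from typing import Optional

def message_to_dict(role: str, content: str, name: Optional[str] = None) -> dict:
    message_dict = {'role': role, 'content': content}
    if name:  # skips both None and '' after the change above
        message_dict['name'] = name
    return message_dict

print(message_to_dict('user', 'hi', name=''))     # {'role': 'user', 'content': 'hi'}
print(message_to_dict('user', 'hi', name='bob'))  # includes 'name': 'bob'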

View File

@ -26,3 +26,4 @@ pricing:
output: '0.002'
unit: '0.001'
currency: USD
deprecated: true

View File

@ -1,16 +1,15 @@
import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Optional
from functools import reduce
from pydub import AudioSegment
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
from flask import Response, stream_with_context
from openai import OpenAI
import concurrent.futures
from pydub import AudioSegment
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):

View File

@ -372,15 +372,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if 'delta' in choice:
delta = choice['delta']
if delta.get('content') is None or delta.get('content') == '':
if finish_reason is not None:
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=choice.get('text', '')),
finish_reason=finish_reason
)
else:
continue
delta_content = delta.get('content')
if delta_content is None or delta_content == '':
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
@ -393,11 +387,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''),
content=delta_content,
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '')
full_assistant_content += delta_content
elif 'text' in choice:
choice_text = choice.get('text', '')
if choice_text == '':
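
The simplification above reads the streamed delta content once into delta_content, then skips empty or missing deltas instead of special-casing a final empty chunk. A reduced sketch of the loop; the chunk shape is a simplified stand-in for the OpenAI-compatible SSE payload, and tool-call handling is omitted:

from typing import Any, Dict, Generator, Iterable

def iter_delta_texts(choices: Iterable[Dict[str, Any]]) -> Generator[str, None, None]:
    """Yield non-empty streamed delta contents, skipping keep-alive chunks."""
    for choice in choices:
        delta = choice.get('delta') or {}
        delta_content = delta.get('content')
        # read once, test once: None and '' are skipped alike
        if delta_content is None or delta_content == '':
            continue
        yield delta_content

print(list(iter_delta_texts([
    {'delta': {'content': 'Hel'}},
    {'delta': {'content': ''}},   # keep-alive, skipped
    {'delta': {'content': 'lo'}},
])))  # ['Hel', 'lo']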

View File

@ -41,7 +41,7 @@ class OpenLLMGenerate(object):
if not server_url:
raise InvalidAuthenticationError('Invalid server URL')
defautl_llm_config = {
default_llm_config = {
"max_new_tokens": 128,
"min_length": 0,
"early_stopping": False,
@ -75,19 +75,19 @@ class OpenLLMGenerate(object):
}
if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int:
defautl_llm_config['max_new_tokens'] = model_parameters['max_tokens']
default_llm_config['max_new_tokens'] = model_parameters['max_tokens']
if 'temperature' in model_parameters and type(model_parameters['temperature']) == float:
defautl_llm_config['temperature'] = model_parameters['temperature']
default_llm_config['temperature'] = model_parameters['temperature']
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
defautl_llm_config['top_p'] = model_parameters['top_p']
default_llm_config['top_p'] = model_parameters['top_p']
if 'top_k' in model_parameters and type(model_parameters['top_k']) == int:
defautl_llm_config['top_k'] = model_parameters['top_k']
default_llm_config['top_k'] = model_parameters['top_k']
if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool:
defautl_llm_config['use_cache'] = model_parameters['use_cache']
default_llm_config['use_cache'] = model_parameters['use_cache']
headers = {
'Content-Type': 'application/json',
@ -104,7 +104,7 @@ class OpenLLMGenerate(object):
data = {
'stop': stop if stop else [],
'prompt': '\n'.join([message.content for message in prompt_messages]),
'llm_config': defautl_llm_config,
'llm_config': default_llm_config,
}
try:
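
Beyond renaming defautl_llm_config, the pattern above overrides each default only when the caller supplied a value of exactly the right type. A hedged sketch of that merge; the parameter names and the strict type(...) comparison (which keeps a bool out of int-typed fields) come from the hunk, while the helper itself is hypothetical:

from typing import Any, Dict

# expected type per overridable parameter, as checked in the code above
PARAM_TYPES = {'max_tokens': int, 'temperature': float,
               'top_p': float, 'top_k': int, 'use_cache': bool}
# the runtime key 'max_tokens' maps onto OpenLLM's 'max_new_tokens'
KEY_MAP = {'max_tokens': 'max_new_tokens'}

def merge_llm_config(defaults: Dict[str, Any], model_parameters: Dict[str, Any]) -> Dict[str, Any]:
    config = dict(defaults)
    for key, expected in PARAM_TYPES.items():
        value = model_parameters.get(key)
        if type(value) is expected:  # exact match, so True won't satisfy int
            config[KEY_MAP.get(key, key)] = value
    return config

print(merge_llm_config({'max_new_tokens': 128, 'temperature': 0.7},
                       {'max_tokens': 256, 'temperature': 'hot'}))
# {'max_new_tokens': 256, 'temperature': 0.7} -- the badly typed temperature is ignored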

View File

@ -35,6 +35,10 @@ class SparkLLMClient:
'spark-3': {
'version': 'v3.1',
'chat_domain': 'generalv3'
},
'spark-3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
}
}

View File

@ -0,0 +1,4 @@
- spark-3.5
- spark-3
- spark-1.5
- spark-2

View File

@ -0,0 +1,33 @@
model: spark-3.5
label:
en_US: Spark V3.5
model_type: llm
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
help:
zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
en_US: Nucleus sampling threshold. Controls how random the results are: the higher the value, the stronger the randomness, and the more likely the same question is to receive different answers.
- name: max_tokens
use_template: max_tokens
default: 2048
min: 1
max: 8192
help:
zh_Hans: 模型回答的tokens的最大长度。
en_US: The maximum length, in tokens, of the model's answer.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
default: 4
min: 1
max: 6
help:
zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
en_US: Randomly select one from k candidates (non-equal probability).
required: false

View File

@ -16,26 +16,5 @@ class SparkProvider(ModelProvider):
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='spark-1.5',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='spark-3',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
# ignore credentials validation because every model has its own spark quota pool
pass

View File

@ -1,14 +1,13 @@
from typing import Generator, List, Optional, Union
from dashscope import get_tokenizer
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dashscope import get_tokenizer
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
from dashscope.common.error import (AuthenticationError, InvalidParameter, RequestFailure, ServiceUnavailableError,
UnsupportedHTTPMethod, UnsupportedModel)
@ -168,7 +167,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, responses: list[Generator],
def _handle_generate_stream_response(self, model: str, credentials: dict, responses: Generator,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@ -182,7 +181,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
for index, response in enumerate(responses):
resp_finish_reason = response.output.finish_reason
resp_content = response.output.text
useage = response.usage
usage = response.usage
if resp_finish_reason is None and (resp_content is None or resp_content == ''):
continue
@ -194,7 +193,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
if resp_finish_reason is not None:
# transform usage
usage = self._calc_response_usage(model, credentials, useage.input_tokens, useage.output_tokens)
usage = self._calc_response_usage(model, credentials, usage.input_tokens, usage.output_tokens)
yield LLMResultChunk(
model=model,

View File

@ -1,16 +1,15 @@
import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Optional
from functools import reduce
from pydub import AudioSegment
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
import dashscope
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
from flask import Response, stream_with_context
import concurrent.futures
from pydub import AudioSegment
class TongyiText2SpeechModel(_CommonTongyi, TTSModel):

View File

@ -3,15 +3,15 @@ from typing import Generator, Iterator, List, Optional, Union, cast
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
ParameterRule, ParameterType)
SystemPromptMessage, ToolPromptMessage, UserPromptMessage)
from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey,
ModelType, ParameterRule, ParameterType)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
XinferenceModelExtraParameter)
from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
XinferenceModelExtraParameter)
from core.model_runtime.utils import helper
from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError)
@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
"""
if 'temperature' in model_parameters:
if model_parameters['temperature'] < 0.01:
model_parameters['temperature'] = 0.01
elif model_parameters['temperature'] > 1.0:
model_parameters['temperature'] = 0.99
return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user,
@ -65,6 +71,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
if extra_param.support_function_call:
credentials['support_function_call'] = True
if extra_param.context_length:
credentials['context_length'] = extra_param.context_length
except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@ -220,6 +232,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
@ -237,7 +252,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
)
),
),
ParameterRule(
name='top_p',
@ -282,6 +297,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
support_function_call = credentials.get('support_function_call', False)
context_length = credentials.get('context_length', 2048)
entity = AIModelEntity(
model=model,
@ -290,8 +308,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length
},
parameter_rules=rules
)
@ -310,6 +332,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
api_key='abc',
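
Two defensive touches in this file are worth isolating: temperature is clamped into the range xinference accepts, and server_url loses its trailing slash so the f'{server_url}/v1' base URL stays well-formed. A sketch of both guards with the exact bounds from the hunks above; the helper names are hypothetical:

from typing import Dict

def clamp_temperature(model_parameters: Dict) -> Dict:
    # same bounds as above: below 0.01 snaps up, above 1.0 snaps down to 0.99
    t = model_parameters.get('temperature')
    if t is not None:
        if t < 0.01:
            model_parameters['temperature'] = 0.01
        elif t > 1.0:
            model_parameters['temperature'] = 0.99
    return model_parameters

def normalize_server_url(credentials: Dict) -> Dict:
    if 'server_url' not in credentials:
        raise ValueError('server_url is required in credentials')
    if credentials['server_url'].endswith('/'):
        credentials['server_url'] = credentials['server_url'][:-1]
    return credentials

print(clamp_temperature({'temperature': 1.2}))                  # {'temperature': 0.99}
print(normalize_server_url({'server_url': 'http://x:9997/'}))   # trailing slash removed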

View File

@ -2,12 +2,13 @@ import time
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
server_url = credentials['server_url']
model_uid = credentials['model_uid']
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url)
try:
@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError:
except (InvokeAuthorizationError, RuntimeError):
raise CredentialsValidateFailedError('Invalid api key')
@property
@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={},
model_properties={
ModelPropertyKey.MAX_CHUNKS: 1,
ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
},
parameter_rules=[]
)
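
The CONTEXT_SIZE expression above leans on the old and/or conditional idiom: it evaluates to credentials['max_tokens'] when the key exists and is truthy, and to 512 otherwise. The equivalent modern spelling, as a sketch:

def context_size(credentials: dict) -> int:
    # equivalent to: 'max_tokens' in credentials and credentials['max_tokens'] or 512
    return credentials.get('max_tokens') or 512

print(context_size({'max_tokens': 4096}))  # 4096
print(context_size({}))                    # 512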

View File

@ -1,3 +1,4 @@
from os import path
from threading import Lock
from time import time
from typing import List
@ -12,11 +13,18 @@ class XinferenceModelExtraParameter(object):
model_format: str
model_handle_type: str
model_ability: List[str]
max_tokens: int = 512
context_length: int = 2048
support_function_call: bool = False
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
support_function_call: bool, max_tokens: int, context_length: int) -> None:
self.model_format = model_format
self.model_handle_type = model_handle_type
self.model_ability = model_ability
self.support_function_call = support_function_call
self.max_tokens = max_tokens
self.context_length = context_length
cache = {}
cache_lock = Lock()
@ -49,9 +57,9 @@ class XinferenceHelper:
get xinference model extra parameter like model_format and model_handle_type
"""
url = f'{server_url}/v1/models/{model_uid}'
url = path.join(server_url, 'v1/models', model_uid)
# this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
# this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
session = Session()
session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3))
@ -66,10 +74,12 @@ class XinferenceHelper:
response_json = response.json()
model_format = response_json['model_format']
model_ability = response_json['model_ability']
model_format = response_json.get('model_format', 'ggmlv3')
model_ability = response_json.get('model_ability', [])
if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
if response_json.get('model_type') == 'embedding':
model_handle_type = 'embedding'
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
model_handle_type = 'chatglm'
elif 'generate' in model_ability:
model_handle_type = 'generate'
@ -78,8 +88,16 @@ class XinferenceHelper:
else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
support_function_call = 'tools' in model_ability
max_tokens = response_json.get('max_tokens', 512)
context_length = response_json.get('context_length', 2048)
return XinferenceModelExtraParameter(
model_format=model_format,
model_handle_type=model_handle_type,
model_ability=model_ability
model_ability=model_ability,
support_function_call=support_function_call,
max_tokens=max_tokens,
context_length=context_length
)
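
get_xinference_extra_parameter now reads the /v1/models/<uid> response defensively, defaulting model_format and model_ability and deriving function-call support from the 'tools' ability. A condensed sketch of that parsing; the defaults come from the hunks, while the 'chat' branch sits in elided context and is assumed here:

from typing import Any, Dict

def parse_extra_parameter(response_json: Dict[str, Any]) -> Dict[str, Any]:
    model_format = response_json.get('model_format', 'ggmlv3')
    model_ability = response_json.get('model_ability', [])
    if response_json.get('model_type') == 'embedding':
        model_handle_type = 'embedding'
    elif model_format == 'ggmlv3' and 'chatglm' in response_json.get('model_name', ''):
        model_handle_type = 'chatglm'
    elif 'generate' in model_ability:
        model_handle_type = 'generate'
    elif 'chat' in model_ability:  # assumed branch, elided between the hunks above
        model_handle_type = 'chat'
    else:
        raise NotImplementedError('unsupported xinference model handle type')
    return {
        'model_format': model_format,
        'model_handle_type': model_handle_type,
        'model_ability': model_ability,
        'support_function_call': 'tools' in model_ability,
        'max_tokens': response_json.get('max_tokens', 512),
        'context_length': response_json.get('context_length', 2048),
    }

print(parse_extra_parameter({'model_ability': ['chat', 'tools'], 'model_name': 'qwen'}))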

View File

@ -2,6 +2,10 @@ model: glm-3-turbo
label:
en_US: glm-3-turbo
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
@ -28,15 +32,3 @@ parameter_rules:
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false
- name: return_type
label:
zh_Hans: 回复类型
en_US: Return Type
type: string
help:
zh_Hans: 用于控制每次返回内容的类型,空或者没有此字段时默认按照 json_string 返回,json_string 返回标准的 JSON 字符串,text 返回原始的文本内容。
en_US: Used to control the type of content returned each time. When it is empty or does not have this field, it will be returned as json_string by default. json_string returns a standard JSON string, and text returns the original text content.
required: false
options:
- text
- json_string

View File

@ -2,6 +2,10 @@ model: glm-4
label:
en_US: glm-4
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
@ -28,15 +32,3 @@ parameter_rules:
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false
- name: return_type
label:
zh_Hans: 回复类型
en_US: Return Type
type: string
help:
zh_Hans: 用于控制每次返回内容的类型,空或者没有此字段时默认按照 json_string 返回,json_string 返回标准的 JSON 字符串,text 返回原始的文本内容。
en_US: Used to control the type of content returned each time. When it is empty or does not have this field, it will be returned as json_string by default. json_string returns a standard JSON string, and text returns the original text content.
required: false
options:
- text
- json_string

View File

@ -30,15 +30,3 @@ parameter_rules:
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false
- name: return_type
label:
zh_Hans: 回复类型
en_US: Return Type
type: string
help:
zh_Hans: 用于控制每次返回内容的类型,空或者没有此字段时默认按照 json_string 返回,json_string 返回标准的 JSON 字符串,text 返回原始的文本内容。
en_US: Used to control the type of content returned each time. When it is empty or does not have this field, it will be returned as json_string by default. json_string returns a standard JSON string, and text returns the original text content.
required: false
options:
- text
- json_string

View File

@ -2,16 +2,19 @@ import json
from typing import Any, Dict, Generator, List, Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage,
TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
PromptMessage, PromptMessageContentType, PromptMessageRole,
PromptMessageTool, SystemPromptMessage,
TextPromptMessageContent, ToolPromptMessage,
UserPromptMessage)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils import helper
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
@ -194,6 +197,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
'content': prompt_message.content,
'tool_call_id': prompt_message.tool_call_id
})
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content,
'tool_calls': [
{
'id': tool_call.id,
'type': tool_call.type,
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments
}
} for tool_call in prompt_message.tool_calls
]
})
else:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content
})
else:
params['messages'].append({
'role': prompt_message.role.value,
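
The new assistant branch above forwards tool_calls to the zhipuai API only when the prompt message actually carries them. A sketch of that projection, with stand-in dataclasses for the runtime's ToolCall entities:

from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class ToolCallFunction:
    name: str
    arguments: str

@dataclass
class ToolCall:
    id: str
    type: str
    function: ToolCallFunction

def assistant_message_to_param(content: str, tool_calls: List[ToolCall]) -> Dict[str, Any]:
    """Build the zhipuai `messages` entry for an assistant turn,
    attaching tool_calls only when present, as in the branch above."""
    message: Dict[str, Any] = {'role': 'assistant', 'content': content}
    if tool_calls:
        message['tool_calls'] = [
            {'id': call.id, 'type': call.type,
             'function': {'name': call.function.name,
                          'arguments': call.function.arguments}}
            for call in tool_calls
        ]
    return message

print(assistant_message_to_param('', [ToolCall(
    id='call_1', type='function',
    function=ToolCallFunction(name='get_weather', arguments='{"city": "Berlin"}'),
)]))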

View File

@ -5,8 +5,8 @@ from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from langchain.schema.language_model import _get_token_ids_default_method

View File

@ -1,17 +1,6 @@
from ._client import ZhipuAI
from .core._errors import (
ZhipuAIError,
APIStatusError,
APIRequestFailedError,
APIAuthenticationError,
APIReachLimitError,
APIInternalError,
APIServerFlowExceedError,
APIResponseError,
APIResponseValidationError,
APITimeoutError,
)
from .__version__ import __version__
from ._client import ZhipuAI
from .core._errors import (APIAuthenticationError, APIInternalError, APIReachLimitError, APIRequestFailedError,
APIResponseError, APIResponseValidationError, APIServerFlowExceedError, APIStatusError,
APITimeoutError, ZhipuAIError)

Some files were not shown because too many files have changed in this diff.