mirror of
https://github.com/langgenius/dify.git
synced 2026-01-21 12:35:21 +08:00
Compare commits
48 Commits
0.7.0
...
fix/note-n
| Author | SHA1 | Date | |
|---|---|---|---|
| d688bebb1a | |||
| f2ad16cec5 | |||
| 68dc6d5bc3 | |||
| acd72e3ab2 | |||
| bbb6fcc4f0 | |||
| fbf31b5d52 | |||
| a0c689c273 | |||
| bfd905602f | |||
| a0a67873aa | |||
| 6cd8ab0cbc | |||
| 5350b1d938 | |||
| baaa3f7f42 | |||
| 4d4af00399 | |||
| 3a33062405 | |||
| 7d4a0a417a | |||
| 5a729a69cd | |||
| dbc1ae45de | |||
| 9e6b755f62 | |||
| a2fafee53a | |||
| c7df6783df | |||
| fcb6921b57 | |||
| 135dcfa3e5 | |||
| acfab01dcf | |||
| 6fdbc7dbf3 | |||
| d1a6702aa4 | |||
| 28944ef6c1 | |||
| 6e7f5fae09 | |||
| ed85d8281a | |||
| f3d3a3a5db | |||
| c89697c49c | |||
| 9414143b5f | |||
| d07b2b9915 | |||
| 04131f86df | |||
| 2d89b7d0a9 | |||
| 603a89055c | |||
| 3f9720bca0 | |||
| 7619850855 | |||
| 3571292fbf | |||
| 8f16165f92 | |||
| 6ff7fd80a1 | |||
| 5aa373dc04 | |||
| 32dc963556 | |||
| 8f5d8397f9 | |||
| 681ec6f845 | |||
| d2ccd8ba53 | |||
| 7f67cb93ec | |||
| d29b32fce2 | |||
| 101db126c8 |
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -45,6 +45,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Ruff formatter check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: poetry run -C api ruff format --check ./api
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
161
api/app.py
161
api/app.py
@ -1,6 +1,6 @@
|
||||
import os
|
||||
|
||||
if os.environ.get("DEBUG", "false").lower() != 'true':
|
||||
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||
from gevent import monkey
|
||||
|
||||
monkey.patch_all()
|
||||
@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
|
||||
if os.name == "nt":
|
||||
os.system('tzutil /s "UTC"')
|
||||
else:
|
||||
os.environ['TZ'] = 'UTC'
|
||||
os.environ["TZ"] = "UTC"
|
||||
time.tzset()
|
||||
|
||||
|
||||
@ -70,13 +70,14 @@ class DifyApp(Flask):
|
||||
# -------------
|
||||
|
||||
|
||||
config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
|
||||
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
# ----------------------------
|
||||
|
||||
|
||||
def create_flask_app_with_configs() -> Flask:
|
||||
"""
|
||||
create a raw flask app
|
||||
@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
|
||||
elif isinstance(value, int | float | bool):
|
||||
os.environ[key] = str(value)
|
||||
elif value is None:
|
||||
os.environ[key] = ''
|
||||
os.environ[key] = ""
|
||||
|
||||
return dify_app
|
||||
|
||||
@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
|
||||
def create_app() -> Flask:
|
||||
app = create_flask_app_with_configs()
|
||||
|
||||
app.secret_key = app.config['SECRET_KEY']
|
||||
app.secret_key = app.config["SECRET_KEY"]
|
||||
|
||||
log_handlers = None
|
||||
log_file = app.config.get('LOG_FILE')
|
||||
log_file = app.config.get("LOG_FILE")
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
@ -111,23 +112,24 @@ def create_app() -> Flask:
|
||||
RotatingFileHandler(
|
||||
filename=log_file,
|
||||
maxBytes=1024 * 1024 * 1024,
|
||||
backupCount=5
|
||||
backupCount=5,
|
||||
),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
logging.StreamHandler(sys.stdout),
|
||||
]
|
||||
|
||||
logging.basicConfig(
|
||||
level=app.config.get('LOG_LEVEL'),
|
||||
format=app.config.get('LOG_FORMAT'),
|
||||
datefmt=app.config.get('LOG_DATEFORMAT'),
|
||||
level=app.config.get("LOG_LEVEL"),
|
||||
format=app.config.get("LOG_FORMAT"),
|
||||
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||
handlers=log_handlers,
|
||||
force=True
|
||||
force=True,
|
||||
)
|
||||
log_tz = app.config.get('LOG_TZ')
|
||||
log_tz = app.config.get("LOG_TZ")
|
||||
if log_tz:
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
timezone = pytz.timezone(log_tz)
|
||||
|
||||
def time_converter(seconds):
|
||||
@ -162,24 +164,24 @@ def initialize_extensions(app):
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in ['console', 'inner_api']:
|
||||
if request.blueprint not in ["console", "inner_api"]:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header:
|
||||
auth_token = request.args.get('_token')
|
||||
auth_token = request.args.get("_token")
|
||||
if not auth_token:
|
||||
raise Unauthorized('Invalid Authorization token.')
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
else:
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get('user_id')
|
||||
user_id = decoded.get("user_id")
|
||||
|
||||
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
||||
if account:
|
||||
@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
|
||||
@login_manager.unauthorized_handler
|
||||
def unauthorized_handler():
|
||||
"""Handle unauthorized requests."""
|
||||
return Response(json.dumps({
|
||||
'code': 'unauthorized',
|
||||
'message': "Unauthorized."
|
||||
}), status=401, content_type="application/json")
|
||||
return Response(
|
||||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||
status=401,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
# register blueprint routers
|
||||
@ -204,38 +207,36 @@ def register_blueprints(app):
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.web import bp as web_bp
|
||||
|
||||
CORS(service_api_bp,
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
||||
)
|
||||
CORS(
|
||||
service_api_bp,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
CORS(web_bp,
|
||||
resources={
|
||||
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
||||
supports_credentials=True,
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
||||
expose_headers=['X-Version', 'X-Env']
|
||||
)
|
||||
CORS(
|
||||
web_bp,
|
||||
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(web_bp)
|
||||
|
||||
CORS(console_app_bp,
|
||||
resources={
|
||||
r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
|
||||
supports_credentials=True,
|
||||
allow_headers=['Content-Type', 'Authorization'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
||||
expose_headers=['X-Version', 'X-Env']
|
||||
)
|
||||
CORS(
|
||||
console_app_bp,
|
||||
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(console_app_bp)
|
||||
|
||||
CORS(files_bp,
|
||||
allow_headers=['Content-Type'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
||||
)
|
||||
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||
app.register_blueprint(files_bp)
|
||||
|
||||
app.register_blueprint(inner_api_bp)
|
||||
@ -245,29 +246,29 @@ def register_blueprints(app):
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
if app.config.get('TESTING'):
|
||||
if app.config.get("TESTING"):
|
||||
print("App is running in TESTING mode")
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.set_cookie('remember_token', '', expires=0)
|
||||
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
||||
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
||||
response.set_cookie("remember_token", "", expires=0)
|
||||
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
|
||||
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
|
||||
return response
|
||||
|
||||
|
||||
@app.route('/health')
|
||||
@app.route("/health")
|
||||
def health():
|
||||
return Response(json.dumps({
|
||||
'pid': os.getpid(),
|
||||
'status': 'ok',
|
||||
'version': app.config['CURRENT_VERSION']
|
||||
}), status=200, content_type="application/json")
|
||||
return Response(
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
|
||||
status=200,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@app.route('/threads')
|
||||
@app.route("/threads")
|
||||
def threads():
|
||||
num_threads = threading.active_count()
|
||||
threads = threading.enumerate()
|
||||
@ -278,32 +279,34 @@ def threads():
|
||||
thread_id = thread.ident
|
||||
is_alive = thread.is_alive()
|
||||
|
||||
thread_list.append({
|
||||
'name': thread_name,
|
||||
'id': thread_id,
|
||||
'is_alive': is_alive
|
||||
})
|
||||
thread_list.append(
|
||||
{
|
||||
"name": thread_name,
|
||||
"id": thread_id,
|
||||
"is_alive": is_alive,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
'pid': os.getpid(),
|
||||
'thread_num': num_threads,
|
||||
'threads': thread_list
|
||||
"pid": os.getpid(),
|
||||
"thread_num": num_threads,
|
||||
"threads": thread_list,
|
||||
}
|
||||
|
||||
|
||||
@app.route('/db-pool-stat')
|
||||
@app.route("/db-pool-stat")
|
||||
def pool_stat():
|
||||
engine = db.engine
|
||||
return {
|
||||
'pid': os.getpid(),
|
||||
'pool_size': engine.pool.size(),
|
||||
'checked_in_connections': engine.pool.checkedin(),
|
||||
'checked_out_connections': engine.pool.checkedout(),
|
||||
'overflow_connections': engine.pool.overflow(),
|
||||
'connection_timeout': engine.pool.timeout(),
|
||||
'recycle_time': db.engine.pool._recycle
|
||||
"pid": os.getpid(),
|
||||
"pool_size": engine.pool.size(),
|
||||
"checked_in_connections": engine.pool.checkedin(),
|
||||
"checked_out_connections": engine.pool.checkedout(),
|
||||
"overflow_connections": engine.pool.overflow(),
|
||||
"connection_timeout": engine.pool.timeout(),
|
||||
"recycle_time": db.engine.pool._recycle,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host='0.0.0.0', port=5001)
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
415
api/commands.py
415
api/commands.py
@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
|
||||
@click.option('--new-password', prompt=True, help='the new password.')
|
||||
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
|
||||
@click.option("--new-password", prompt=True, help="the new password.")
|
||||
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
|
||||
def reset_password(email, new_password, password_confirm):
|
||||
"""
|
||||
Reset password of owner account
|
||||
Only available in SELF_HOSTED mode
|
||||
"""
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
|
||||
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
|
||||
return
|
||||
|
||||
account = db.session.query(Account). \
|
||||
filter(Account.email == email). \
|
||||
one_or_none()
|
||||
account = db.session.query(Account).filter(Account.email == email).one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
|
||||
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(
|
||||
click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
|
||||
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
|
||||
return
|
||||
|
||||
# generate password salt
|
||||
@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
|
||||
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
|
||||
|
||||
|
||||
@click.command('reset-email', help='Reset the account email.')
|
||||
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
|
||||
@click.option('--new-email', prompt=True, help='the new email.')
|
||||
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
|
||||
@click.option("--new-email", prompt=True, help="the new email.")
|
||||
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
|
||||
def reset_email(email, new_email, email_confirm):
|
||||
"""
|
||||
Replace account email
|
||||
:return:
|
||||
"""
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
|
||||
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
|
||||
return
|
||||
|
||||
account = db.session.query(Account). \
|
||||
filter(Account.email == email). \
|
||||
one_or_none()
|
||||
account = db.session.query(Account).filter(Account.email == email).one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
|
||||
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(
|
||||
click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
|
||||
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
|
||||
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
|
||||
|
||||
|
||||
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
|
||||
'After the reset, all LLM credentials will become invalid, '
|
||||
'requiring re-entry.'
|
||||
'Only support SELF_HOSTED mode.')
|
||||
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
|
||||
' this operation cannot be rolled back!', fg='red'))
|
||||
@click.command(
|
||||
"reset-encrypt-key-pair",
|
||||
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
|
||||
"After the reset, all LLM credentials will become invalid, "
|
||||
"requiring re-entry."
|
||||
"Only support SELF_HOSTED mode.",
|
||||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
"""
|
||||
Reset the encrypted key pair of workspace for encrypt LLM credentials.
|
||||
After the reset, all LLM credentials will become invalid, requiring re-entry.
|
||||
Only support SELF_HOSTED mode.
|
||||
"""
|
||||
if dify_config.EDITION != 'SELF_HOSTED':
|
||||
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
|
||||
return
|
||||
|
||||
tenants = db.session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
|
||||
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
|
||||
return
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
|
||||
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(click.style('Congratulations! '
|
||||
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
|
||||
click.echo(
|
||||
click.style(
|
||||
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command('vdb-migrate', help='migrate vector db.')
|
||||
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
|
||||
@click.command("vdb-migrate", help="migrate vector db.")
|
||||
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
||||
def vdb_migrate(scope: str):
|
||||
if scope in ['knowledge', 'all']:
|
||||
if scope in ["knowledge", "all"]:
|
||||
migrate_knowledge_vector_database()
|
||||
if scope in ['annotation', 'all']:
|
||||
if scope in ["annotation", "all"]:
|
||||
migrate_annotation_vector_database()
|
||||
|
||||
|
||||
@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style('Start migrate annotation data.', fg='green'))
|
||||
click.echo(click.style("Start migrate annotation data.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
|
||||
while True:
|
||||
try:
|
||||
# get apps info
|
||||
apps = db.session.query(App).filter(
|
||||
App.status == 'normal'
|
||||
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.filter(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.paginate(page=page, per_page=50)
|
||||
)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for app in apps:
|
||||
total_count = total_count + 1
|
||||
click.echo(f'Processing the {total_count} app {app.id}. '
|
||||
+ f'{create_count} created, {skipped_count} skipped.')
|
||||
click.echo(
|
||||
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo('Create app annotation index: {}'.format(app.id))
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app.id
|
||||
).first()
|
||||
click.echo("Create app annotation index: {}".format(app.id))
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo('App annotation setting is disabled: {}'.format(app.id))
|
||||
click.echo("App annotation setting is disabled: {}".format(app.id))
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
|
||||
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
|
||||
).first()
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
|
||||
click.echo("App annotation collection binding is not exist: {}".format(app.id))
|
||||
continue
|
||||
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
)
|
||||
documents = []
|
||||
if annotations:
|
||||
for annotation in annotations:
|
||||
document = Document(
|
||||
page_content=annotation.question,
|
||||
metadata={
|
||||
"annotation_id": annotation.id,
|
||||
"app_id": app.id,
|
||||
"doc_id": annotation.id
|
||||
}
|
||||
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
|
||||
)
|
||||
documents.append(document)
|
||||
|
||||
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(f'Successfully delete vector index for app: {app.id}.',
|
||||
fg='green'))
|
||||
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f'Failed to delete vector index for app {app.id}.',
|
||||
fg='red'))
|
||||
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
if documents:
|
||||
try:
|
||||
click.echo(click.style(
|
||||
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
|
||||
fg='green'))
|
||||
vector.create(documents)
|
||||
click.echo(
|
||||
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
|
||||
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
click.echo(f'Successfully migrated app annotation {app.id}.')
|
||||
click.echo(f"Successfully migrated app annotation {app.id}.")
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.style(
|
||||
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
|
||||
fg='green'))
|
||||
click.style(
|
||||
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style('Start migrate vector db.', fg='green'))
|
||||
click.echo(click.style("Start migrate vector db.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
datasets = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.indexing_technique == "high_quality")
|
||||
.order_by(Dataset.created_at.desc())
|
||||
.paginate(page=page, per_page=50)
|
||||
)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
total_count = total_count + 1
|
||||
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
|
||||
+ f'{create_count} created, {skipped_count} skipped.')
|
||||
click.echo(
|
||||
f"Processing the {total_count} dataset {dataset.id}. "
|
||||
+ f"{create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo('Create dataset vdb index: {}'.format(dataset.id))
|
||||
click.echo("Create dataset vdb index: {}".format(dataset.id))
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] == vector_type:
|
||||
if dataset.index_struct_dict["type"] == vector_type:
|
||||
skipped_count = skipped_count + 1
|
||||
continue
|
||||
collection_name = ''
|
||||
collection_name = ""
|
||||
if vector_type == VectorType.WEAVIATE:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.WEAVIATE,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
raise ValueError("Dataset Collection Bindings is not exist!")
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.QDRANT,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
elif vector_type == VectorType.MILVUS:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.MILVUS,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.RELYT:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'relyt',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.TENCENT:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.TENCENT,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.PGVECTOR:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.PGVECTOR,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.OPENSEARCH:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.OPENSEARCH,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ANALYTICDB:
|
||||
@ -341,16 +340,13 @@ def migrate_knowledge_vector_database():
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.ANALYTICDB,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ELASTICSEARCH:
|
||||
dataset_id = dataset.id
|
||||
index_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'elasticsearch',
|
||||
"vector_store": {"class_prefix": index_name}
|
||||
}
|
||||
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
@ -361,29 +357,41 @@ def migrate_knowledge_vector_database():
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
|
||||
fg='green'))
|
||||
click.style(
|
||||
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
|
||||
fg='red'))
|
||||
click.style(
|
||||
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
documents = []
|
||||
segments_count = 0
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
@ -393,7 +401,7 @@ def migrate_knowledge_vector_database():
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
@ -401,37 +409,43 @@ def migrate_knowledge_vector_database():
|
||||
|
||||
if documents:
|
||||
try:
|
||||
click.echo(click.style(
|
||||
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
|
||||
fg='green'))
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
click.echo(
|
||||
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
|
||||
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
|
||||
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
|
||||
raise e
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
click.echo(f'Successfully migrated dataset {dataset.id}.')
|
||||
click.echo(f"Successfully migrated dataset {dataset.id}.")
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
|
||||
)
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
|
||||
fg='green'))
|
||||
click.style(
|
||||
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
|
||||
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
|
||||
def convert_to_agent_apps():
|
||||
"""
|
||||
Convert Agent Assistant to Agent App.
|
||||
"""
|
||||
click.echo(click.style('Start convert to agent apps.', fg='green'))
|
||||
click.echo(click.style("Start convert to agent apps.", fg="green"))
|
||||
|
||||
proceeded_app_ids = []
|
||||
|
||||
@ -466,7 +480,7 @@ def convert_to_agent_apps():
|
||||
break
|
||||
|
||||
for app in apps:
|
||||
click.echo('Converting app: {}'.format(app.id))
|
||||
click.echo("Converting app: {}".format(app.id))
|
||||
|
||||
try:
|
||||
app.mode = AppMode.AGENT_CHAT.value
|
||||
@ -478,137 +492,139 @@ def convert_to_agent_apps():
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
|
||||
click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
|
||||
str(e)), fg='red'))
|
||||
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
|
||||
|
||||
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
|
||||
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
|
||||
|
||||
|
||||
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
|
||||
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
|
||||
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
|
||||
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
|
||||
def add_qdrant_doc_id_index(field: str):
|
||||
click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
|
||||
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
if vector_type != "qdrant":
|
||||
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
|
||||
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
|
||||
return
|
||||
create_count = 0
|
||||
|
||||
try:
|
||||
bindings = db.session.query(DatasetCollectionBinding).all()
|
||||
if not bindings:
|
||||
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
|
||||
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
import qdrant_client
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
raise ValueError('Qdrant url is required.')
|
||||
raise ValueError("Qdrant url is required.")
|
||||
qdrant_config = QdrantConfig(
|
||||
endpoint=dify_config.QDRANT_URL,
|
||||
api_key=dify_config.QDRANT_API_KEY,
|
||||
root_path=current_app.root_path,
|
||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
)
|
||||
try:
|
||||
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
|
||||
# create payload index
|
||||
client.create_payload_index(binding.collection_name, field,
|
||||
field_schema=PayloadSchemaType.KEYWORD)
|
||||
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
||||
create_count += 1
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
|
||||
click.echo(
|
||||
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
|
||||
)
|
||||
continue
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
click.echo(click.style('Failed to create qdrant client.', fg='red'))
|
||||
click.echo(click.style("Failed to create qdrant client.", fg="red"))
|
||||
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Create {create_count} collection indexes.',
|
||||
fg='green'))
|
||||
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
|
||||
|
||||
|
||||
@click.command('create-tenant', help='Create account and tenant.')
|
||||
@click.option('--email', prompt=True, help='The email address of the tenant account.')
|
||||
@click.option('--language', prompt=True, help='Account language, default: en-US.')
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
@click.option("--email", prompt=True, help="The email address of the tenant account.")
|
||||
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||
def create_tenant(email: str, language: Optional[str] = None):
|
||||
"""
|
||||
Create tenant account
|
||||
"""
|
||||
if not email:
|
||||
click.echo(click.style('Sorry, email is required.', fg='red'))
|
||||
click.echo(click.style("Sorry, email is required.", fg="red"))
|
||||
return
|
||||
|
||||
# Create account
|
||||
email = email.strip()
|
||||
|
||||
if '@' not in email:
|
||||
click.echo(click.style('Sorry, invalid email address.', fg='red'))
|
||||
if "@" not in email:
|
||||
click.echo(click.style("Sorry, invalid email address.", fg="red"))
|
||||
return
|
||||
|
||||
account_name = email.split('@')[0]
|
||||
account_name = email.split("@")[0]
|
||||
|
||||
if language not in languages:
|
||||
language = 'en-US'
|
||||
language = "en-US"
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(
|
||||
email=email,
|
||||
name=account_name,
|
||||
password=new_password,
|
||||
language=language
|
||||
)
|
||||
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
click.echo(click.style('Congratulations! Account and tenant created.\n'
|
||||
'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
|
||||
click.echo(
|
||||
click.style(
|
||||
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command('upgrade-db', help='upgrade the database')
|
||||
@click.command("upgrade-db", help="upgrade the database")
|
||||
def upgrade_db():
|
||||
click.echo('Preparing database migration...')
|
||||
lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
|
||||
click.echo("Preparing database migration...")
|
||||
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
|
||||
if lock.acquire(blocking=False):
|
||||
try:
|
||||
click.echo(click.style('Start database migration.', fg='green'))
|
||||
click.echo(click.style("Start database migration.", fg="green"))
|
||||
|
||||
# run db migration
|
||||
import flask_migrate
|
||||
|
||||
flask_migrate.upgrade()
|
||||
|
||||
click.echo(click.style('Database migration successful!', fg='green'))
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f'Database migration failed, error: {e}')
|
||||
logging.exception(f"Database migration failed, error: {e}")
|
||||
finally:
|
||||
lock.release()
|
||||
else:
|
||||
click.echo('Database migration skipped')
|
||||
click.echo("Database migration skipped")
|
||||
|
||||
|
||||
@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
|
||||
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
|
||||
def fix_app_site_missing():
|
||||
"""
|
||||
Fix app related site missing issue.
|
||||
"""
|
||||
click.echo(click.style('Start fix app related site missing issue.', fg='green'))
|
||||
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
|
||||
|
||||
failed_app_ids = []
|
||||
while True:
|
||||
@ -639,15 +655,14 @@ where sites.id is null limit 1000"""
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception as e:
|
||||
failed_app_ids.append(app_id)
|
||||
click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
|
||||
logging.exception(f'Fix app related site missing issue failed, error: {e}')
|
||||
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
|
||||
logging.exception(f"Fix app related site missing issue failed, error: {e}")
|
||||
continue
|
||||
|
||||
if not processed_count:
|
||||
break
|
||||
|
||||
|
||||
click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
|
||||
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
|
||||
@ -1 +1 @@
|
||||
HIDDEN_VALUE = '[__HIDDEN__]'
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
|
||||
@ -1,22 +1,22 @@
|
||||
language_timezone_mapping = {
|
||||
'en-US': 'America/New_York',
|
||||
'zh-Hans': 'Asia/Shanghai',
|
||||
'zh-Hant': 'Asia/Taipei',
|
||||
'pt-BR': 'America/Sao_Paulo',
|
||||
'es-ES': 'Europe/Madrid',
|
||||
'fr-FR': 'Europe/Paris',
|
||||
'de-DE': 'Europe/Berlin',
|
||||
'ja-JP': 'Asia/Tokyo',
|
||||
'ko-KR': 'Asia/Seoul',
|
||||
'ru-RU': 'Europe/Moscow',
|
||||
'it-IT': 'Europe/Rome',
|
||||
'uk-UA': 'Europe/Kyiv',
|
||||
'vi-VN': 'Asia/Ho_Chi_Minh',
|
||||
'ro-RO': 'Europe/Bucharest',
|
||||
'pl-PL': 'Europe/Warsaw',
|
||||
'hi-IN': 'Asia/Kolkata',
|
||||
'tr-TR': 'Europe/Istanbul',
|
||||
'fa-IR': 'Asia/Tehran',
|
||||
"en-US": "America/New_York",
|
||||
"zh-Hans": "Asia/Shanghai",
|
||||
"zh-Hant": "Asia/Taipei",
|
||||
"pt-BR": "America/Sao_Paulo",
|
||||
"es-ES": "Europe/Madrid",
|
||||
"fr-FR": "Europe/Paris",
|
||||
"de-DE": "Europe/Berlin",
|
||||
"ja-JP": "Asia/Tokyo",
|
||||
"ko-KR": "Asia/Seoul",
|
||||
"ru-RU": "Europe/Moscow",
|
||||
"it-IT": "Europe/Rome",
|
||||
"uk-UA": "Europe/Kyiv",
|
||||
"vi-VN": "Asia/Ho_Chi_Minh",
|
||||
"ro-RO": "Europe/Bucharest",
|
||||
"pl-PL": "Europe/Warsaw",
|
||||
"hi-IN": "Asia/Kolkata",
|
||||
"tr-TR": "Europe/Istanbul",
|
||||
"fa-IR": "Asia/Tehran",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
@ -26,6 +26,5 @@ def supported_language(lang):
|
||||
if lang in languages:
|
||||
return lang
|
||||
|
||||
error = ('{lang} is not a valid language.'
|
||||
.format(lang=lang))
|
||||
error = "{lang} is not a valid language.".format(lang=lang)
|
||||
raise ValueError(error)
|
||||
|
||||
@ -5,82 +5,79 @@ from models.model import AppMode
|
||||
default_app_templates = {
|
||||
# workflow default mode
|
||||
AppMode.WORKFLOW: {
|
||||
'app': {
|
||||
'mode': AppMode.WORKFLOW.value,
|
||||
'enable_site': True,
|
||||
'enable_api': True
|
||||
"app": {
|
||||
"mode": AppMode.WORKFLOW.value,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
}
|
||||
},
|
||||
|
||||
# completion default mode
|
||||
AppMode.COMPLETION: {
|
||||
'app': {
|
||||
'mode': AppMode.COMPLETION.value,
|
||||
'enable_site': True,
|
||||
'enable_api': True
|
||||
"app": {
|
||||
"mode": AppMode.COMPLETION.value,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
'model_config': {
|
||||
'model': {
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {}
|
||||
"completion_params": {},
|
||||
},
|
||||
'user_input_form': json.dumps([
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
]),
|
||||
'pre_prompt': '{{query}}'
|
||||
"user_input_form": json.dumps(
|
||||
[
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": "",
|
||||
},
|
||||
},
|
||||
]
|
||||
),
|
||||
"pre_prompt": "{{query}}",
|
||||
},
|
||||
|
||||
},
|
||||
|
||||
# chat default mode
|
||||
AppMode.CHAT: {
|
||||
'app': {
|
||||
'mode': AppMode.CHAT.value,
|
||||
'enable_site': True,
|
||||
'enable_api': True
|
||||
"app": {
|
||||
"mode": AppMode.CHAT.value,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
'model_config': {
|
||||
'model': {
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {}
|
||||
}
|
||||
}
|
||||
"completion_params": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
# advanced-chat default mode
|
||||
AppMode.ADVANCED_CHAT: {
|
||||
'app': {
|
||||
'mode': AppMode.ADVANCED_CHAT.value,
|
||||
'enable_site': True,
|
||||
'enable_api': True
|
||||
}
|
||||
"app": {
|
||||
"mode": AppMode.ADVANCED_CHAT.value,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
},
|
||||
|
||||
# agent-chat default mode
|
||||
AppMode.AGENT_CHAT: {
|
||||
'app': {
|
||||
'mode': AppMode.AGENT_CHAT.value,
|
||||
'enable_site': True,
|
||||
'enable_api': True
|
||||
"app": {
|
||||
"mode": AppMode.AGENT_CHAT.value,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
'model_config': {
|
||||
'model': {
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
"completion_params": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -1,3 +1,7 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar('tenant_id')
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
|
||||
|
||||
@ -61,6 +61,7 @@ class AppListApi(Resource):
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
@ -94,6 +95,7 @@ class AppImportApi(Resource):
|
||||
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
@ -167,6 +169,7 @@ class AppApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument('max_active_requests', type=int, location='json')
|
||||
@ -208,6 +211,7 @@ class AppCopyApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -33,7 +33,7 @@ class CompletionConversationApi(Resource):
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
def get(self, app_model):
|
||||
if not current_user.is_admin_or_owner:
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('keyword', type=str, location='args')
|
||||
@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
def get(self, app_model, conversation_id):
|
||||
if not current_user.is_admin_or_owner:
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
@ -119,7 +119,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_admin_or_owner:
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
@ -256,7 +256,7 @@ class ChatConversationDetailApi(Resource):
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_admin_or_owner:
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from models.model import Site
|
||||
def parse_app_site_args():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('title', type=str, required=False, location='json')
|
||||
parser.add_argument('icon_type', type=str, required=False, location='json')
|
||||
parser.add_argument('icon', type=str, required=False, location='json')
|
||||
parser.add_argument('icon_background', type=str, required=False, location='json')
|
||||
parser.add_argument('description', type=str, required=False, location='json')
|
||||
@ -53,6 +54,7 @@ class AppSite(Resource):
|
||||
|
||||
for attr_name in [
|
||||
'title',
|
||||
'icon_type',
|
||||
'icon',
|
||||
'icon_background',
|
||||
'description',
|
||||
|
||||
@ -459,6 +459,7 @@ class ConvertToWorkflowApi(Resource):
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -53,19 +53,22 @@ class SegmentApi(DatasetApiResource):
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
for args_item in args['segments']:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
if args['segments'] is not None:
|
||||
for args_item in args['segments']:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
else:
|
||||
return {"error": "Segemtns is required"}, 400
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
|
||||
@ -6,6 +6,7 @@ from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
from models.account import TenantStatus
|
||||
from models.model import Site
|
||||
from services.feature_service import FeatureService
|
||||
@ -28,8 +29,10 @@ class AppSiteApi(WebApiResource):
|
||||
'title': fields.String,
|
||||
'chat_color_theme': fields.String,
|
||||
'chat_color_theme_inverted': fields.Boolean,
|
||||
'icon_type': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'icon_url': AppIconUrlField,
|
||||
'description': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
|
||||
@ -64,15 +64,19 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
Agent runner
|
||||
:param tenant_id: tenant id
|
||||
:param application_generate_entity: application generate entity
|
||||
:param conversation: conversation
|
||||
:param app_config: app generate entity
|
||||
:param model_config: model config
|
||||
:param config: dataset config
|
||||
:param queue_manager: queue manager
|
||||
:param message: message
|
||||
:param user_id: user id
|
||||
:param agent_llm_callback: agent llm callback
|
||||
:param callback: callback
|
||||
:param memory: memory
|
||||
:param prompt_messages: prompt messages
|
||||
:param variables_pool: variables pool
|
||||
:param db_variables: db variables
|
||||
:param model_instance: model instance
|
||||
"""
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
@ -445,7 +449,7 @@ class BaseAgentRunner(AppRunner):
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception as e:
|
||||
tool_responses = { tool: agent_thought.observation for tool in tools }
|
||||
tool_responses = dict.fromkeys(tools, agent_thought.observation)
|
||||
|
||||
for tool in tools:
|
||||
# generate a uuid for tool call
|
||||
|
||||
@ -292,6 +292,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
handle invoke action
|
||||
:param action: action
|
||||
:param tool_instances: tool instances
|
||||
:param message_file_ids: message file ids
|
||||
:param trace_manager: trace manager
|
||||
:return: observation, meta
|
||||
"""
|
||||
# action is tool call, invoke tool
|
||||
|
||||
@ -93,6 +93,7 @@ class DatasetConfigManager:
|
||||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights'),
|
||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||
rerank_mode=dataset_configs["reranking_mode"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@ from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
# db.session.refresh(conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id=message.id
|
||||
)
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
conversation.dialogue_count += 1
|
||||
|
||||
conversation_id = conversation.id
|
||||
conversation_dialogue_count = conversation.dialogue_count
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation_id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
contexts.workflow_variable_pool.set(variable_pool)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
'user': user,
|
||||
'context': contextvars.copy_context()
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||
@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user: Account,
|
||||
context: contextvars.Context) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False) \
|
||||
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
def _handle_advanced_chat_response(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -4,9 +4,6 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import (
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
from models import App, Message, Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
session.commit()
|
||||
# Convert database entities to variables
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
def single_iteration_run(
|
||||
@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import contexts
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -47,7 +48,8 @@ from core.file.file_obj import FileVar
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
from events.message_event import message_was_created
|
||||
@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
# Deprecated
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||
@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._workflow = workflow
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
# Deprecated
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
@ -245,8 +249,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if (message.event
|
||||
and hasattr(message.event, 'metadata')
|
||||
and message.event.metadata
|
||||
and getattr(message.event, 'metadata', None)
|
||||
and message.event.metadata.get('is_answer_previous_node', False)
|
||||
and publisher):
|
||||
publisher.publish(message=message)
|
||||
@ -613,7 +616,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
if route_chunk_node_id == 'sys':
|
||||
# system variable
|
||||
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
|
||||
value = contexts.workflow_variable_pool.get().get(value_selector)
|
||||
if value:
|
||||
value = value.text
|
||||
elif route_chunk_node_id in self._iteration_nested_relations:
|
||||
# it's a iteration variable
|
||||
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class AppRunner:
|
||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get pre calculate rest tokens
|
||||
@ -126,7 +128,7 @@ class AppRunner:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
@ -254,6 +256,7 @@ class AppRunner:
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param stream: stream
|
||||
:param agent: agent
|
||||
:return:
|
||||
"""
|
||||
if not stream:
|
||||
@ -276,6 +279,7 @@ class AppRunner:
|
||||
Handle invoke result direct
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param agent: agent
|
||||
:return:
|
||||
"""
|
||||
queue_manager.publish(
|
||||
@ -291,6 +295,7 @@ class AppRunner:
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param agent: agent
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
@ -366,7 +371,7 @@ class AppRunner:
|
||||
message_id=message_id,
|
||||
trace_manager=app_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
|
||||
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
prompt_messages: list[PromptMessage]) -> bool:
|
||||
@ -418,7 +423,7 @@ class AppRunner:
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
|
||||
def query_app_annotations_to_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
|
||||
@ -138,6 +138,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
"""
|
||||
Initialize generate records
|
||||
:param application_generate_entity: application generate entity
|
||||
:conversation conversation
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
@ -258,7 +259,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
return introduction
|
||||
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
def _get_conversation(self, conversation_id: str):
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
@ -270,6 +271,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
|
||||
@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -42,7 +42,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@ -519,7 +520,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
||||
@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -13,11 +12,9 @@ from .segments import (
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
@ -32,7 +29,6 @@ __all__ = [
|
||||
'FloatVariable',
|
||||
'ObjectVariable',
|
||||
'SecretVariable',
|
||||
'FileVariable',
|
||||
'StringVariable',
|
||||
'ArrayAnyVariable',
|
||||
'Variable',
|
||||
@ -45,11 +41,9 @@ __all__ = [
|
||||
'FloatSegment',
|
||||
'ObjectSegment',
|
||||
'ArrayAnySegment',
|
||||
'FileSegment',
|
||||
'StringSegment',
|
||||
'ArrayStringVariable',
|
||||
'ArrayNumberVariable',
|
||||
'ArrayObjectVariable',
|
||||
'ArrayFileVariable',
|
||||
'ArraySegment',
|
||||
]
|
||||
|
||||
@ -2,12 +2,10 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .exc import VariableError
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -17,11 +15,9 @@ from .segments import (
|
||||
)
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise VariableError(f'invalid number value {value}')
|
||||
case SegmentType.FILE:
|
||||
result = FileVariable.model_validate(mapping)
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
result = ObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||
@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_FILE if isinstance(value, list):
|
||||
mapping = dict(mapping)
|
||||
mapping['value'] = [{'value': v} for v in value]
|
||||
result = ArrayFileVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f'not supported value type {value_type}')
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, list):
|
||||
return ArrayAnySegment(value=value)
|
||||
if isinstance(value, FileVar):
|
||||
return FileSegment(value=value)
|
||||
raise ValueError(f'not supported value {value}')
|
||||
|
||||
@ -5,8 +5,6 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
@ -78,14 +76,7 @@ class IntegerSegment(Segment):
|
||||
value: int
|
||||
|
||||
|
||||
class FileSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.FILE
|
||||
# TODO: embed FileVar in this model.
|
||||
value: FileVar
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return self.value.to_markdown()
|
||||
|
||||
|
||||
class ObjectSegment(Segment):
|
||||
@ -108,7 +99,13 @@ class ObjectSegment(Segment):
|
||||
class ArraySegment(Segment):
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return '\n'.join(['- ' + item.markdown for item in self.value])
|
||||
items = []
|
||||
for item in self.value:
|
||||
if hasattr(item, 'to_markdown'):
|
||||
items.append(item.to_markdown())
|
||||
else:
|
||||
items.append(str(item))
|
||||
return '\n'.join(items)
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
@ -130,7 +127,3 @@ class ArrayObjectSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||
value: Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
class ArrayFileSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||
value: Sequence[FileSegment]
|
||||
|
||||
@ -10,8 +10,6 @@ class SegmentType(str, Enum):
|
||||
ARRAY_STRING = 'array[string]'
|
||||
ARRAY_NUMBER = 'array[number]'
|
||||
ARRAY_OBJECT = 'array[object]'
|
||||
ARRAY_FILE = 'array[file]'
|
||||
OBJECT = 'object'
|
||||
FILE = 'file'
|
||||
|
||||
GROUP = 'group'
|
||||
|
||||
@ -4,11 +4,9 @@ from core.helper import encrypter
|
||||
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class FileVariable(FileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectVariable(ObjectSegment, Variable):
|
||||
pass
|
||||
|
||||
@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayFileVariable(ArrayFileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class SecretVariable(StringVariable):
|
||||
value_type: SegmentType = SegmentType.SECRET
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.enums import SystemVariable
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
|
||||
@ -99,7 +99,7 @@ class MessageFileParser:
|
||||
# return all file objs
|
||||
return new_files
|
||||
|
||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
|
||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
||||
"""
|
||||
transform message files
|
||||
|
||||
@ -144,7 +144,7 @@ class MessageFileParser:
|
||||
|
||||
return type_file_objs
|
||||
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
||||
"""
|
||||
transform file to file obj
|
||||
|
||||
|
||||
@ -271,9 +271,8 @@ class ModelInstance:
|
||||
|
||||
:param content_text: text content to be translated
|
||||
:param tenant_id: user tenant id
|
||||
:param user: unique user id
|
||||
:param voice: model timbre
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
@ -401,6 +400,10 @@ class LBModelManager:
|
||||
managed_credentials: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Load balancing model manager
|
||||
:param tenant_id: tenant_id
|
||||
:param provider: provider
|
||||
:param model_type: model_type
|
||||
:param model: model name
|
||||
:param load_balancing_configs: all load balancing configurations
|
||||
:param managed_credentials: credentials if load balancing configuration name is __inherit__
|
||||
"""
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from core.model_runtime.entities.model_entities import DefaultParameterName
|
||||
|
||||
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||
@ -94,5 +93,16 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||
},
|
||||
'required': False,
|
||||
'options': ['JSON', 'XML'],
|
||||
}
|
||||
}
|
||||
},
|
||||
DefaultParameterName.JSON_SCHEMA: {
|
||||
'label': {
|
||||
'en_US': 'JSON Schema',
|
||||
},
|
||||
'type': 'text',
|
||||
'help': {
|
||||
'en_US': 'Set a response json schema will ensure LLM to adhere it.',
|
||||
'zh_Hans': '设置返回的json schema,llm将按照它返回',
|
||||
},
|
||||
'required': False,
|
||||
},
|
||||
}
|
||||
|
||||
@ -95,6 +95,7 @@ class DefaultParameterName(Enum):
|
||||
FREQUENCY_PENALTY = "frequency_penalty"
|
||||
MAX_TOKENS = "max_tokens"
|
||||
RESPONSE_FORMAT = "response_format"
|
||||
JSON_SCHEMA = "json_schema"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> 'DefaultParameterName':
|
||||
@ -118,6 +119,7 @@ class ParameterType(Enum):
|
||||
INT = "int"
|
||||
STRING = "string"
|
||||
BOOLEAN = "boolean"
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
class ModelPropertyKey(Enum):
|
||||
|
||||
@ -84,7 +84,8 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
def _add_custom_parameters(self, credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
|
||||
|
||||
def _add_function_call(self, model: str, credentials: dict) -> None:
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
@ -31,6 +31,14 @@ provider_credential_schema:
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: Base URL, 如:https://api.moonshot.cn/v1
|
||||
en_US: Base URL, e.g. https://api.moonshot.cn/v1
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
- gpt-4o
|
||||
- gpt-4o-2024-05-13
|
||||
- gpt-4o-2024-08-06
|
||||
- chatgpt-4o-latest
|
||||
- gpt-4o-mini
|
||||
- gpt-4o-mini-2024-07-18
|
||||
- gpt-4-turbo
|
||||
|
||||
@ -0,0 +1,44 @@
|
||||
model: chatgpt-4o-latest
|
||||
label:
|
||||
zh_Hans: chatgpt-4o-latest
|
||||
en_US: chatgpt-4o-latest
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '2.50'
|
||||
output: '10.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -37,6 +37,9 @@ parameter_rules:
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- json_schema
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '2.50'
|
||||
output: '10.00'
|
||||
|
||||
@ -37,6 +37,9 @@ parameter_rules:
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- json_schema
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.15'
|
||||
output: '0.60'
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
@ -544,13 +545,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
|
||||
response_format = model_parameters.get("response_format")
|
||||
if response_format:
|
||||
if response_format == "json_object":
|
||||
response_format = {"type": "json_object"}
|
||||
if response_format == "json_schema":
|
||||
json_schema = model_parameters.get("json_schema")
|
||||
if not json_schema:
|
||||
raise ValueError("Must define JSON Schema when the response format is json_schema")
|
||||
try:
|
||||
schema = json.loads(json_schema)
|
||||
except:
|
||||
raise ValueError(f"not currect json_schema format: {json_schema}")
|
||||
model_parameters.pop("json_schema")
|
||||
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
|
||||
else:
|
||||
response_format = {"type": "text"}
|
||||
|
||||
model_parameters["response_format"] = response_format
|
||||
|
||||
model_parameters["response_format"] = {"type": response_format}
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
@ -922,11 +928,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if model.startswith('ft:'):
|
||||
model = model.split(':')[1]
|
||||
|
||||
# Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
|
||||
if model == "chatgpt-4o-latest":
|
||||
model = "gpt-4o"
|
||||
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
@ -946,7 +955,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"See https://platform.openai.com/docs/advanced-usage/managing-tokens for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
|
||||
@ -0,0 +1,61 @@
|
||||
model: Llama3-Chinese_v2
|
||||
label:
|
||||
en_US: Llama3-Chinese_v2
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -0,0 +1,61 @@
|
||||
model: Meta-Llama-3-70B-Instruct-GPTQ-Int4
|
||||
label:
|
||||
en_US: Meta-Llama-3-70B-Instruct-GPTQ-Int4
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1024
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -0,0 +1,61 @@
|
||||
model: Meta-Llama-3-8B-Instruct
|
||||
label:
|
||||
en_US: Meta-Llama-3-8B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -0,0 +1,61 @@
|
||||
model: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
|
||||
label:
|
||||
en_US: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 410960
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -0,0 +1,61 @@
|
||||
model: Meta-Llama-3.1-8B-Instruct
|
||||
label:
|
||||
en_US: Meta-Llama-3.1-8B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.1
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -55,7 +55,8 @@ parameter_rules:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
||||
@ -55,7 +55,8 @@ parameter_rules:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
||||
@ -6,7 +6,7 @@ features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
context_size: 2048
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
@ -55,7 +55,7 @@ parameter_rules:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
|
||||
@ -6,7 +6,7 @@ features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 8192
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
@ -55,7 +55,7 @@ parameter_rules:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
|
||||
@ -8,12 +8,12 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
context_size: 2048
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
@ -57,7 +57,7 @@ parameter_rules:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
|
||||
@ -0,0 +1,61 @@
|
||||
model: Qwen2-72B-Instruct
|
||||
label:
|
||||
en_US: Qwen2-72B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -8,7 +8,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 8192
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
@ -57,7 +57,7 @@ parameter_rules:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
|
||||
@ -1,6 +1,15 @@
|
||||
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
|
||||
- Meta-Llama-3.1-8B-Instruct
|
||||
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
|
||||
- Meta-Llama-3-8B-Instruct
|
||||
- Qwen2-72B-Instruct-GPTQ-Int4
|
||||
- Qwen2-72B-Instruct
|
||||
- Qwen2-7B
|
||||
- Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
- Qwen-14B-Chat-Int4
|
||||
- Qwen1.5-72B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-7B
|
||||
- Qwen-14B-Chat-Int4
|
||||
- Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
- deepseek-v2-chat
|
||||
- deepseek-v2-lite-chat
|
||||
- Llama3-Chinese_v2
|
||||
- chatglm3-6b
|
||||
|
||||
@ -0,0 +1,61 @@
|
||||
model: chatglm3-6b
|
||||
label:
|
||||
en_US: chatglm3-6b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -0,0 +1,61 @@
|
||||
model: deepseek-v2-chat
|
||||
label:
|
||||
en_US: deepseek-v2-chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -0,0 +1,61 @@
|
||||
model: deepseek-v2-lite-chat
|
||||
label:
|
||||
en_US: deepseek-v2-lite-chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2048
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.5
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: "0.000"
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
@ -0,0 +1,4 @@
|
||||
model: BAAI/bge-large-en-v1.5
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32768
|
||||
@ -0,0 +1,4 @@
|
||||
model: BAAI/bge-large-zh-v1.5
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32768
|
||||
@ -0,0 +1,4 @@
|
||||
model: netease-youdao/bce-reranker-base_v1
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 512
|
||||
@ -0,0 +1,4 @@
|
||||
model: BAAI/bge-reranker-v2-m3
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 8192
|
||||
@ -0,0 +1,87 @@
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class SiliconflowRerankModel(RerankModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
|
||||
score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) -> RerankResult:
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1')
|
||||
if base_url.endswith('/'):
|
||||
base_url = base_url[:-1]
|
||||
try:
|
||||
response = httpx.post(
|
||||
base_url + '/rerank',
|
||||
json={
|
||||
"model": model,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"top_n": top_n,
|
||||
"return_documents": True
|
||||
},
|
||||
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
rerank_documents = []
|
||||
for result in results['results']:
|
||||
rerank_document = RerankDocument(
|
||||
index=result['index'],
|
||||
text=result['document']['text'],
|
||||
score=result['relevance_score'],
|
||||
)
|
||||
if score_threshold is None or result['relevance_score'] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [httpx.ConnectError],
|
||||
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError]
|
||||
}
|
||||
@ -12,10 +12,11 @@ help:
|
||||
en_US: Get your API Key from SiliconFlow
|
||||
zh_Hans: 从 SiliconFlow 获取 API Key
|
||||
url:
|
||||
en_US: https://cloud.siliconflow.cn/keys
|
||||
en_US: https://cloud.siliconflow.cn/account/ak
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
- speech2text
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
|
||||
@ -159,6 +159,8 @@ You should also complete the text started with ``` but not tell ``` directly.
|
||||
"""
|
||||
if model in ['qwen-turbo-chat', 'qwen-plus-chat']:
|
||||
model = model.replace('-chat', '')
|
||||
if model == 'farui-plus':
|
||||
model = 'qwen-farui-plus'
|
||||
|
||||
if model in self.tokenizers:
|
||||
tokenizer = self.tokenizers[model]
|
||||
|
||||
@ -1 +1 @@
|
||||
- soloar-1-mini-chat
|
||||
- solar-1-mini-chat
|
||||
|
||||
@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
|
||||
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
|
||||
get_model_config,
|
||||
get_v2_req_params,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
-> LLMResult | Generator:
|
||||
|
||||
client = MaaSClient.from_credential(credentials)
|
||||
|
||||
req_params = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('req_params', {}).copy()
|
||||
if credentials.get('context_size'):
|
||||
req_params['max_prompt_tokens'] = credentials.get('context_size')
|
||||
if credentials.get('max_tokens'):
|
||||
req_params['max_new_tokens'] = credentials.get('max_tokens')
|
||||
if model_parameters.get('max_tokens'):
|
||||
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
|
||||
if model_parameters.get('temperature'):
|
||||
req_params['temperature'] = model_parameters.get('temperature')
|
||||
if model_parameters.get('top_p'):
|
||||
req_params['top_p'] = model_parameters.get('top_p')
|
||||
if model_parameters.get('top_k'):
|
||||
req_params['top_k'] = model_parameters.get('top_k')
|
||||
if model_parameters.get('presence_penalty'):
|
||||
req_params['presence_penalty'] = model_parameters.get(
|
||||
'presence_penalty')
|
||||
if model_parameters.get('frequency_penalty'):
|
||||
req_params['frequency_penalty'] = model_parameters.get(
|
||||
'frequency_penalty')
|
||||
if stop:
|
||||
req_params['stop'] = stop
|
||||
|
||||
req_params = get_v2_req_params(credentials, model_parameters, stop)
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs['tools'] = [
|
||||
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
|
||||
]
|
||||
|
||||
resp = MaaSClient.wrap_exception(
|
||||
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
||||
if not stream:
|
||||
@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
max_tokens = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
|
||||
if credentials.get('max_tokens'):
|
||||
max_tokens = int(credentials.get('max_tokens'))
|
||||
model_config = get_model_config(credentials)
|
||||
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
name='presence_penalty',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='presence_penalty',
|
||||
label={
|
||||
'en_US': 'Presence Penalty',
|
||||
'zh_Hans': '存在惩罚',
|
||||
},
|
||||
label=I18nObject(
|
||||
en_US='Presence Penalty',
|
||||
zh_Hans= '存在惩罚',
|
||||
),
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
),
|
||||
@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
name='frequency_penalty',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='frequency_penalty',
|
||||
label={
|
||||
'en_US': 'Frequency Penalty',
|
||||
'zh_Hans': '频率惩罚',
|
||||
},
|
||||
label=I18nObject(
|
||||
en_US= 'Frequency Penalty',
|
||||
zh_Hans= '频率惩罚',
|
||||
),
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
),
|
||||
@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
max=max_tokens,
|
||||
max=model_config.properties.max_tokens,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度',
|
||||
@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
),
|
||||
]
|
||||
|
||||
model_properties = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('model_properties', {}).copy()
|
||||
if credentials.get('mode'):
|
||||
model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
|
||||
if credentials.get('context_size'):
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
||||
credentials.get('context_size', 4096))
|
||||
|
||||
model_features = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('features', [])
|
||||
|
||||
model_properties = {}
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
|
||||
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
model_type=ModelType.LLM,
|
||||
model_properties=model_properties,
|
||||
parameter_rules=rules,
|
||||
features=model_features,
|
||||
features=model_config.features,
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -1,181 +1,123 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
ModelConfigs = {
|
||||
'Doubao-pro-4k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 4096,
|
||||
'max_new_tokens': 4096,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-lite-4k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 4096,
|
||||
'max_new_tokens': 4096,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-pro-32k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 32768,
|
||||
'max_new_tokens': 32768,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 32768,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-lite-32k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 32768,
|
||||
'max_new_tokens': 32768,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 32768,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-pro-128k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 131072,
|
||||
'max_new_tokens': 131072,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 131072,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-lite-128k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 131072,
|
||||
'max_new_tokens': 131072,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 131072,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Skylark2-pro-4k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 4096,
|
||||
'max_new_tokens': 4000,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Llama3-8B': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 8192,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Llama3-70B': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 8192,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Moonshot-v1-8k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 4096,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Moonshot-v1-32k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 32768,
|
||||
'max_new_tokens': 16384,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 32768,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Moonshot-v1-128k': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 131072,
|
||||
'max_new_tokens': 65536,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 131072,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'GLM3-130B': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 4096,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'GLM3-130B-Fin': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 4096,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Mistral-7B': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 2048,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
},
|
||||
'features': [],
|
||||
}
|
||||
|
||||
class ModelProperties(BaseModel):
|
||||
context_size: int
|
||||
max_tokens: int
|
||||
mode: LLMMode
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
properties: ModelProperties
|
||||
features: list[ModelFeature]
|
||||
|
||||
|
||||
configs: dict[str, ModelConfig] = {
|
||||
'Doubao-pro-4k': ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL]
|
||||
),
|
||||
'Doubao-lite-4k': ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL]
|
||||
),
|
||||
'Doubao-pro-32k': ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL]
|
||||
),
|
||||
'Doubao-lite-32k': ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL]
|
||||
),
|
||||
'Doubao-pro-128k': ModelConfig(
|
||||
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL]
|
||||
),
|
||||
'Doubao-lite-128k': ModelConfig(
|
||||
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL]
|
||||
),
|
||||
'Skylark2-pro-4k': ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'Llama3-8B': ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'Llama3-70B': ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'Moonshot-v1-8k': ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'Moonshot-v1-32k': ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'Moonshot-v1-128k': ModelConfig(
|
||||
properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'GLM3-130B': ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'GLM3-130B-Fin': ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
),
|
||||
'Mistral-7B': ModelConfig(
|
||||
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
|
||||
features=[]
|
||||
)
|
||||
}
|
||||
|
||||
def get_model_config(credentials: dict)->ModelConfig:
|
||||
base_model = credentials.get('base_model_name', '')
|
||||
model_configs = configs.get(base_model)
|
||||
if not model_configs:
|
||||
return ModelConfig(
|
||||
properties=ModelProperties(
|
||||
context_size=int(credentials.get('context_size', 0)),
|
||||
max_tokens=int(credentials.get('max_tokens', 0)),
|
||||
mode= LLMMode.value_of(credentials.get('mode', 'chat')),
|
||||
),
|
||||
features=[]
|
||||
)
|
||||
return model_configs
|
||||
|
||||
|
||||
def get_v2_req_params(credentials: dict, model_parameters: dict,
|
||||
stop: list[str] | None=None):
|
||||
req_params = {}
|
||||
# predefined properties
|
||||
model_configs = get_model_config(credentials)
|
||||
if model_configs:
|
||||
req_params['max_prompt_tokens'] = model_configs.properties.context_size
|
||||
req_params['max_new_tokens'] = model_configs.properties.max_tokens
|
||||
|
||||
# model parameters
|
||||
if model_parameters.get('max_tokens'):
|
||||
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
|
||||
if model_parameters.get('temperature'):
|
||||
req_params['temperature'] = model_parameters.get('temperature')
|
||||
if model_parameters.get('top_p'):
|
||||
req_params['top_p'] = model_parameters.get('top_p')
|
||||
if model_parameters.get('top_k'):
|
||||
req_params['top_k'] = model_parameters.get('top_k')
|
||||
if model_parameters.get('presence_penalty'):
|
||||
req_params['presence_penalty'] = model_parameters.get(
|
||||
'presence_penalty')
|
||||
if model_parameters.get('frequency_penalty'):
|
||||
req_params['frequency_penalty'] = model_parameters.get(
|
||||
'frequency_penalty')
|
||||
|
||||
if stop:
|
||||
req_params['stop'] = stop
|
||||
|
||||
return req_params
|
||||
@ -1,9 +1,27 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelProperties(BaseModel):
|
||||
context_size: int
|
||||
max_chunks: int
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
properties: ModelProperties
|
||||
|
||||
ModelConfigs = {
|
||||
'Doubao-embedding': {
|
||||
'req_params': {},
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'max_chunks': 1,
|
||||
}
|
||||
},
|
||||
'Doubao-embedding': ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_chunks=1)
|
||||
),
|
||||
}
|
||||
|
||||
def get_model_config(credentials: dict)->ModelConfig:
|
||||
base_model = credentials.get('base_model_name', '')
|
||||
model_configs = ModelConfigs.get(base_model)
|
||||
if not model_configs:
|
||||
return ModelConfig(
|
||||
properties=ModelProperties(
|
||||
context_size=int(credentials.get('context_size', 0)),
|
||||
max_chunks=int(credentials.get('max_chunks', 0)),
|
||||
)
|
||||
)
|
||||
return model_configs
|
||||
@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
|
||||
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
|
||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
||||
|
||||
|
||||
@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
model_properties = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('model_properties', {}).copy()
|
||||
if credentials.get('context_size'):
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
||||
credentials.get('context_size', 4096))
|
||||
if credentials.get('max_chunks'):
|
||||
model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
|
||||
credentials.get('max_chunks', 4096))
|
||||
model_config = get_model_config(credentials)
|
||||
model_properties = {}
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
|
||||
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
|
||||
195
api/core/model_runtime/model_providers/wenxin/_common.py
Normal file
195
api/core/model_runtime/model_providers/wenxin/_common.py
Normal file
@ -0,0 +1,195 @@
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
|
||||
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
|
||||
baidu_access_tokens_lock = Lock()
|
||||
|
||||
|
||||
class BaiduAccessToken:
|
||||
api_key: str
|
||||
access_token: str
|
||||
expires: datetime
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.access_token = ''
|
||||
self.expires = datetime.now() + timedelta(days=3)
|
||||
|
||||
@staticmethod
|
||||
def _get_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
request access token from Baidu
|
||||
"""
|
||||
try:
|
||||
response = post(
|
||||
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
|
||||
headers={
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
|
||||
|
||||
resp = response.json()
|
||||
if 'error' in resp:
|
||||
if resp['error'] == 'invalid_client':
|
||||
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
|
||||
elif resp['error'] == 'unknown_error':
|
||||
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
|
||||
elif resp['error'] == 'invalid_request':
|
||||
raise BadRequestError(f'Bad request: {resp["error_description"]}')
|
||||
elif resp['error'] == 'rate_limit_exceeded':
|
||||
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
||||
else:
|
||||
raise Exception(f'Unknown error: {resp["error_description"]}')
|
||||
|
||||
return resp['access_token']
|
||||
|
||||
@staticmethod
|
||||
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
|
||||
"""
|
||||
LLM from Baidu requires access token to invoke the API.
|
||||
however, we have api_key and secret_key, and access token is valid for 30 days.
|
||||
so we can cache the access token for 3 days. (avoid memory leak)
|
||||
|
||||
it may be more efficient to use a ticker to refresh access token, but it will cause
|
||||
more complexity, so we just refresh access tokens when get_access_token is called.
|
||||
"""
|
||||
|
||||
# loop up cache, remove expired access token
|
||||
baidu_access_tokens_lock.acquire()
|
||||
now = datetime.now()
|
||||
for key in list(baidu_access_tokens.keys()):
|
||||
token = baidu_access_tokens[key]
|
||||
if token.expires < now:
|
||||
baidu_access_tokens.pop(key)
|
||||
|
||||
if api_key not in baidu_access_tokens:
|
||||
# if access token not in cache, request it
|
||||
token = BaiduAccessToken(api_key)
|
||||
baidu_access_tokens[api_key] = token
|
||||
# release it to enhance performance
|
||||
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
|
||||
baidu_access_tokens_lock.release()
|
||||
# try to get access token
|
||||
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
|
||||
token.access_token = token_str
|
||||
token.expires = now + timedelta(days=3)
|
||||
return token
|
||||
else:
|
||||
# if access token in cache, return it
|
||||
token = baidu_access_tokens[api_key]
|
||||
baidu_access_tokens_lock.release()
|
||||
return token
|
||||
|
||||
|
||||
class _CommonWenxin:
|
||||
api_bases = {
|
||||
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
||||
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
||||
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
||||
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
|
||||
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
'ernie-bot',
|
||||
'ernie-bot-8k',
|
||||
'ernie-3.5-8k',
|
||||
'ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k',
|
||||
'ernie-4.0-8k',
|
||||
'ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat'
|
||||
]
|
||||
|
||||
api_key: str = ''
|
||||
secret_key: str = ''
|
||||
|
||||
def __init__(self, api_key: str, secret_key: str):
|
||||
self.api_key = api_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
@staticmethod
|
||||
def _to_credential_kwargs(credentials: dict) -> dict:
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['api_key'],
|
||||
"secret_key": credentials['secret_key']
|
||||
}
|
||||
return credentials_kwargs
|
||||
|
||||
def _handle_error(self, code: int, msg: str):
|
||||
error_map = {
|
||||
1: InternalServerError,
|
||||
2: InternalServerError,
|
||||
3: BadRequestError,
|
||||
4: RateLimitReachedError,
|
||||
6: InvalidAuthenticationError,
|
||||
13: InvalidAPIKeyError,
|
||||
14: InvalidAPIKeyError,
|
||||
15: InvalidAPIKeyError,
|
||||
17: RateLimitReachedError,
|
||||
18: RateLimitReachedError,
|
||||
19: RateLimitReachedError,
|
||||
100: InvalidAPIKeyError,
|
||||
111: InvalidAPIKeyError,
|
||||
200: InternalServerError,
|
||||
336000: InternalServerError,
|
||||
336001: BadRequestError,
|
||||
336002: BadRequestError,
|
||||
336003: BadRequestError,
|
||||
336004: InvalidAuthenticationError,
|
||||
336005: InvalidAPIKeyError,
|
||||
336006: BadRequestError,
|
||||
336007: BadRequestError,
|
||||
336008: BadRequestError,
|
||||
336100: InternalServerError,
|
||||
336101: BadRequestError,
|
||||
336102: BadRequestError,
|
||||
336103: BadRequestError,
|
||||
336104: BadRequestError,
|
||||
336105: BadRequestError,
|
||||
336200: InternalServerError,
|
||||
336303: BadRequestError,
|
||||
337006: BadRequestError
|
||||
}
|
||||
|
||||
if code in error_map:
|
||||
raise error_map[code](msg)
|
||||
else:
|
||||
raise InternalServerError(f'Unknown error: {msg}')
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
||||
return token.access_token
|
||||
@ -1,102 +1,17 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from json import dumps, loads
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
|
||||
from requests import Response, post
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
|
||||
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
|
||||
# map api_key to access_token
|
||||
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
|
||||
baidu_access_tokens_lock = Lock()
|
||||
|
||||
class BaiduAccessToken:
|
||||
api_key: str
|
||||
access_token: str
|
||||
expires: datetime
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.access_token = ''
|
||||
self.expires = datetime.now() + timedelta(days=3)
|
||||
|
||||
def _get_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
request access token from Baidu
|
||||
"""
|
||||
try:
|
||||
response = post(
|
||||
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
|
||||
headers={
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
|
||||
|
||||
resp = response.json()
|
||||
if 'error' in resp:
|
||||
if resp['error'] == 'invalid_client':
|
||||
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
|
||||
elif resp['error'] == 'unknown_error':
|
||||
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
|
||||
elif resp['error'] == 'invalid_request':
|
||||
raise BadRequestError(f'Bad request: {resp["error_description"]}')
|
||||
elif resp['error'] == 'rate_limit_exceeded':
|
||||
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
||||
else:
|
||||
raise Exception(f'Unknown error: {resp["error_description"]}')
|
||||
|
||||
return resp['access_token']
|
||||
|
||||
@staticmethod
|
||||
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
|
||||
"""
|
||||
LLM from Baidu requires access token to invoke the API.
|
||||
however, we have api_key and secret_key, and access token is valid for 30 days.
|
||||
so we can cache the access token for 3 days. (avoid memory leak)
|
||||
|
||||
it may be more efficient to use a ticker to refresh access token, but it will cause
|
||||
more complexity, so we just refresh access tokens when get_access_token is called.
|
||||
"""
|
||||
|
||||
# loop up cache, remove expired access token
|
||||
baidu_access_tokens_lock.acquire()
|
||||
now = datetime.now()
|
||||
for key in list(baidu_access_tokens.keys()):
|
||||
token = baidu_access_tokens[key]
|
||||
if token.expires < now:
|
||||
baidu_access_tokens.pop(key)
|
||||
|
||||
if api_key not in baidu_access_tokens:
|
||||
# if access token not in cache, request it
|
||||
token = BaiduAccessToken(api_key)
|
||||
baidu_access_tokens[api_key] = token
|
||||
# release it to enhance performance
|
||||
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
|
||||
baidu_access_tokens_lock.release()
|
||||
# try to get access token
|
||||
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
|
||||
token.access_token = token_str
|
||||
token.expires = now + timedelta(days=3)
|
||||
return token
|
||||
else:
|
||||
# if access token in cache, return it
|
||||
token = baidu_access_tokens[api_key]
|
||||
baidu_access_tokens_lock.release()
|
||||
return token
|
||||
|
||||
|
||||
class ErnieMessage:
|
||||
class Role(Enum):
|
||||
@ -120,51 +35,7 @@ class ErnieMessage:
|
||||
self.content = content
|
||||
self.role = role
|
||||
|
||||
class ErnieBotModel:
|
||||
api_bases = {
|
||||
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
||||
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
||||
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
||||
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
'ernie-bot',
|
||||
'ernie-bot-8k',
|
||||
'ernie-3.5-8k',
|
||||
'ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k',
|
||||
'ernie-4.0-8k',
|
||||
'ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat'
|
||||
]
|
||||
|
||||
api_key: str = ''
|
||||
secret_key: str = ''
|
||||
|
||||
def __init__(self, api_key: str, secret_key: str):
|
||||
self.api_key = api_key
|
||||
self.secret_key = secret_key
|
||||
class ErnieBotModel(_CommonWenxin):
|
||||
|
||||
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
|
||||
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
|
||||
@ -199,51 +70,6 @@ class ErnieBotModel:
|
||||
return self._handle_chat_stream_generate_response(resp)
|
||||
return self._handle_chat_generate_response(resp)
|
||||
|
||||
def _handle_error(self, code: int, msg: str):
|
||||
error_map = {
|
||||
1: InternalServerError,
|
||||
2: InternalServerError,
|
||||
3: BadRequestError,
|
||||
4: RateLimitReachedError,
|
||||
6: InvalidAuthenticationError,
|
||||
13: InvalidAPIKeyError,
|
||||
14: InvalidAPIKeyError,
|
||||
15: InvalidAPIKeyError,
|
||||
17: RateLimitReachedError,
|
||||
18: RateLimitReachedError,
|
||||
19: RateLimitReachedError,
|
||||
100: InvalidAPIKeyError,
|
||||
111: InvalidAPIKeyError,
|
||||
200: InternalServerError,
|
||||
336000: InternalServerError,
|
||||
336001: BadRequestError,
|
||||
336002: BadRequestError,
|
||||
336003: BadRequestError,
|
||||
336004: InvalidAuthenticationError,
|
||||
336005: InvalidAPIKeyError,
|
||||
336006: BadRequestError,
|
||||
336007: BadRequestError,
|
||||
336008: BadRequestError,
|
||||
336100: InternalServerError,
|
||||
336101: BadRequestError,
|
||||
336102: BadRequestError,
|
||||
336103: BadRequestError,
|
||||
336104: BadRequestError,
|
||||
336105: BadRequestError,
|
||||
336200: InternalServerError,
|
||||
336303: BadRequestError,
|
||||
337006: BadRequestError
|
||||
}
|
||||
|
||||
if code in error_map:
|
||||
raise error_map[code](msg)
|
||||
else:
|
||||
raise InternalServerError(f'Unknown error: {msg}')
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
||||
return token.access_token
|
||||
|
||||
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
|
||||
return [ErnieMessage(message.content, message.role) for message in messages]
|
||||
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
class InvalidAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
class InvalidAPIKeyError(Exception):
|
||||
pass
|
||||
|
||||
class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
pass
|
||||
|
||||
class InternalServerError(Exception):
|
||||
pass
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
||||
@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
|
||||
|
||||
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
api_key = credentials['api_key']
|
||||
secret_key = credentials['secret_key']
|
||||
try:
|
||||
BaiduAccessToken._get_access_token(api_key, secret_key)
|
||||
BaiduAccessToken.get_access_token(api_key, secret_key)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||
|
||||
@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
return invoke_error_mapping()
|
||||
|
||||
@ -0,0 +1,9 @@
|
||||
model: embedding-v1
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 384
|
||||
max_chunks: 16
|
||||
pricing:
|
||||
input: '0.0005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -0,0 +1,184 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from json import dumps
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from requests import Response, post
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
invoke_error_mapping,
|
||||
)
|
||||
|
||||
|
||||
class TextEmbedding:
|
||||
@abstractmethod
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
access_token = self._get_access_token()
|
||||
url = f'{self.api_bases[model]}?access_token={access_token}'
|
||||
body = self._build_embed_request_body(model, texts, user)
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
resp = post(url, data=dumps(body), headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
|
||||
return self._handle_embed_response(model, resp)
|
||||
|
||||
def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
|
||||
if len(texts) == 0:
|
||||
raise BadRequestError('The number of texts should not be zero.')
|
||||
body = {
|
||||
'input': texts,
|
||||
'user_id': user,
|
||||
}
|
||||
return body
|
||||
|
||||
def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
|
||||
data = response.json()
|
||||
if 'error_code' in data:
|
||||
code = data['error_code']
|
||||
msg = data['error_msg']
|
||||
# raise error
|
||||
self._handle_error(code, msg)
|
||||
|
||||
embeddings = [v['embedding'] for v in data['data']]
|
||||
_usage = data['usage']
|
||||
tokens = _usage['prompt_tokens']
|
||||
total_tokens = _usage['total_tokens']
|
||||
|
||||
return embeddings, tokens, total_tokens
|
||||
|
||||
|
||||
class WenxinTextEmbeddingModel(TextEmbeddingModel):
|
||||
def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return WenxinTextEmbedding(api_key, secret_key)
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, texts: list[str],
|
||||
user: Optional[str] = None) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
api_key = credentials['api_key']
|
||||
secret_key = credentials['secret_key']
|
||||
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
|
||||
user = user if user else 'ErnieBotDefault'
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
used_total_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
for i in _iter:
|
||||
embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
|
||||
model,
|
||||
inputs[i: i + max_chunks],
|
||||
user)
|
||||
used_tokens += _used_tokens
|
||||
used_total_tokens += _total_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
|
||||
return TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=batched_embeddings,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
total_num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: Mapping) -> None:
|
||||
api_key = credentials['api_key']
|
||||
secret_key = credentials['secret_key']
|
||||
try:
|
||||
BaiduAccessToken.get_access_token(api_key, secret_key)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return invoke_error_mapping()
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=total_tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
@ -17,6 +17,7 @@ help:
|
||||
en_US: https://cloud.baidu.com/wenxin.html
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
|
||||
@ -0,0 +1,57 @@
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
|
||||
|
||||
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class InvalidAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
class InvalidAPIKeyError(Exception):
|
||||
pass
|
||||
|
||||
class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
pass
|
||||
|
||||
class InternalServerError(Exception):
|
||||
pass
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
||||
@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
tools=tools, stop=stop, stream=stream, user=user,
|
||||
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
model_uid=credentials['model_uid'],
|
||||
api_key=credentials.get('api_key'),
|
||||
)
|
||||
)
|
||||
|
||||
@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
model_uid=credentials['model_uid'],
|
||||
api_key=credentials.get('api_key')
|
||||
)
|
||||
if 'completion_type' not in credentials:
|
||||
if 'chat' in extra_param.model_ability:
|
||||
@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
else:
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
model_uid=credentials['model_uid'],
|
||||
api_key=credentials.get('api_key')
|
||||
)
|
||||
|
||||
if 'chat' in extra_args.model_ability:
|
||||
@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
xinference_client = Client(
|
||||
base_url=credentials['server_url'],
|
||||
api_key=credentials.get('api_key'),
|
||||
)
|
||||
|
||||
xinference_model = xinference_client.get_model(credentials['model_uid'])
|
||||
|
||||
@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
base_url=credentials['server_url'],
|
||||
api_key=credentials.get('api_key'),
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
base_url=credentials['server_url'],
|
||||
api_key=credentials.get('api_key'),
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
server_url = credentials['server_url']
|
||||
model_uid = credentials['model_uid']
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
|
||||
api_key = credentials.get('api_key')
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=server_url,
|
||||
model_uid=model_uid,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
if extra_args.max_tokens:
|
||||
credentials['max_tokens'] = extra_args.max_tokens
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
client = Client(
|
||||
base_url=server_url,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
|
||||
@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):
|
||||
|
||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
model_uid=credentials['model_uid'],
|
||||
api_key=credentials.get('api_key'),
|
||||
)
|
||||
|
||||
if 'text-to-audio' not in extra_param.model_ability:
|
||||
@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
try:
|
||||
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
|
||||
api_key = credentials.get('api_key')
|
||||
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||
handle = RESTfulAudioModelHandle(
|
||||
credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
|
||||
)
|
||||
|
||||
model_support_voice = [x.get("value") for x in
|
||||
self.get_tts_model_voices(model=model, credentials=credentials)]
|
||||
|
||||
@ -35,13 +35,13 @@ cache_lock = Lock()
|
||||
|
||||
class XinferenceHelper:
|
||||
@staticmethod
|
||||
def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
|
||||
def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
|
||||
XinferenceHelper._clean_cache()
|
||||
with cache_lock:
|
||||
if model_uid not in cache:
|
||||
cache[model_uid] = {
|
||||
'expires': time() + 300,
|
||||
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
|
||||
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
|
||||
}
|
||||
return cache[model_uid]['value']
|
||||
|
||||
@ -56,7 +56,7 @@ class XinferenceHelper:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
|
||||
def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
|
||||
"""
|
||||
get xinference model extra parameter like model_format and model_handle_type
|
||||
"""
|
||||
@ -70,9 +70,10 @@ class XinferenceHelper:
|
||||
session = Session()
|
||||
session.mount('http://', HTTPAdapter(max_retries=3))
|
||||
session.mount('https://', HTTPAdapter(max_retries=3))
|
||||
headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||
|
||||
try:
|
||||
response = session.get(url, timeout=10)
|
||||
response = session.get(url, headers=headers, timeout=10)
|
||||
except (MissingSchema, ConnectionError, Timeout) as e:
|
||||
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
|
||||
if response.status_code != 200:
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import enum
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from core.app.app_config.entities import PromptTemplateEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.model import AppMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class ModelMode(enum.Enum):
|
||||
COMPLETION = 'completion'
|
||||
@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> \
|
||||
@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
|
||||
return [self.get_last_user_message(prompt, files)], stops
|
||||
|
||||
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
|
||||
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
|
||||
if files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||
for file in files:
|
||||
|
||||
@ -811,7 +811,7 @@ class ProviderManager:
|
||||
-> list[ModelSettings]:
|
||||
"""
|
||||
Convert to model settings.
|
||||
|
||||
:param provider_entity: provider entity
|
||||
:param provider_model_settings: provider model settings include enabled, load balancing enabled
|
||||
:param load_balancing_model_configs: load balancing model configs
|
||||
:return:
|
||||
|
||||
@ -21,6 +21,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:param save_as: save as
|
||||
:return: the image message
|
||||
"""
|
||||
```
|
||||
@ -34,6 +35,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:param save_as: save as
|
||||
:return: the link message
|
||||
"""
|
||||
```
|
||||
@ -47,6 +49,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
|
||||
create a text message
|
||||
|
||||
:param text: the text of the message
|
||||
:param save_as: save as
|
||||
:return: the text message
|
||||
"""
|
||||
```
|
||||
@ -63,6 +66,8 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:param meta: meta
|
||||
:param save_as: save as
|
||||
:return: the blob message
|
||||
"""
|
||||
```
|
||||
|
||||
@ -29,6 +29,6 @@ class GitlabProvider(BuiltinToolProviderController):
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError((response.json()).get('message'))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e))
|
||||
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -2,37 +2,37 @@ identity:
|
||||
author: Leo.Wang
|
||||
name: gitlab
|
||||
label:
|
||||
en_US: Gitlab
|
||||
zh_Hans: Gitlab
|
||||
en_US: GitLab
|
||||
zh_Hans: GitLab
|
||||
description:
|
||||
en_US: Gitlab plugin for commit
|
||||
zh_Hans: 用于获取Gitlab commit的插件
|
||||
en_US: GitLab plugin, API v4 only.
|
||||
zh_Hans: 用于获取GitLab内容的插件,目前仅支持 API v4。
|
||||
icon: gitlab.svg
|
||||
credentials_for_provider:
|
||||
access_tokens:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Gitlab access token
|
||||
zh_Hans: Gitlab access token
|
||||
en_US: GitLab access token
|
||||
zh_Hans: GitLab access token
|
||||
placeholder:
|
||||
en_US: Please input your Gitlab access token
|
||||
zh_Hans: 请输入你的 Gitlab access token
|
||||
en_US: Please input your GitLab access token
|
||||
zh_Hans: 请输入你的 GitLab access token
|
||||
help:
|
||||
en_US: Get your Gitlab access token from Gitlab
|
||||
zh_Hans: 从 Gitlab 获取您的 access token
|
||||
en_US: Get your GitLab access token from GitLab
|
||||
zh_Hans: 从 GitLab 获取您的 access token
|
||||
url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
|
||||
site_url:
|
||||
type: text-input
|
||||
required: false
|
||||
default: 'https://gitlab.com'
|
||||
label:
|
||||
en_US: Gitlab site url
|
||||
zh_Hans: Gitlab site url
|
||||
en_US: GitLab site url
|
||||
zh_Hans: GitLab site url
|
||||
placeholder:
|
||||
en_US: Please input your Gitlab site url
|
||||
zh_Hans: 请输入你的 Gitlab site url
|
||||
en_US: Please input your GitLab site url
|
||||
zh_Hans: 请输入你的 GitLab site url
|
||||
help:
|
||||
en_US: Find your Gitlab url
|
||||
zh_Hans: 找到你的Gitlab url
|
||||
en_US: Find your GitLab url
|
||||
zh_Hans: 找到你的 GitLab url
|
||||
url: https://gitlab.com/help
|
||||
|
||||
@ -18,6 +18,7 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
employee = tool_parameters.get('employee', '')
|
||||
start_time = tool_parameters.get('start_time', '')
|
||||
end_time = tool_parameters.get('end_time', '')
|
||||
change_type = tool_parameters.get('change_type', 'all')
|
||||
|
||||
if not project:
|
||||
return self.create_text_message('Project is required')
|
||||
@ -36,11 +37,11 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
site_url = 'https://gitlab.com'
|
||||
|
||||
# Get commit content
|
||||
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
|
||||
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
|
||||
|
||||
return self.create_text_message(json.dumps(result, ensure_ascii=False))
|
||||
return [self.create_json_message(item) for item in result]
|
||||
|
||||
def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
|
||||
def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
@ -74,7 +75,7 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
|
||||
for commit in commits:
|
||||
commit_sha = commit['id']
|
||||
print(f"\tCommit SHA: {commit_sha}")
|
||||
author_name = commit['author_name']
|
||||
|
||||
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
|
||||
diff_response = requests.get(diff_url, headers=headers)
|
||||
@ -87,14 +88,23 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
removed_lines = diff['diff'].count('\n-')
|
||||
total_changes = added_lines + removed_lines
|
||||
|
||||
if total_changes > 1:
|
||||
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
|
||||
results.append({
|
||||
"project": project_name,
|
||||
"commit_sha": commit_sha,
|
||||
"diff": final_code
|
||||
})
|
||||
print(f"Commit code:{final_code}")
|
||||
if change_type == "new":
|
||||
if added_lines > 1:
|
||||
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
|
||||
results.append({
|
||||
"commit_sha": commit_sha,
|
||||
"author_name": author_name,
|
||||
"diff": final_code
|
||||
})
|
||||
else:
|
||||
if total_changes > 1:
|
||||
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')])
|
||||
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
|
||||
results.append({
|
||||
"commit_sha": commit_sha,
|
||||
"author_name": author_name,
|
||||
"diff": final_code_escaped
|
||||
})
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
|
||||
@ -2,24 +2,24 @@ identity:
|
||||
name: gitlab_commits
|
||||
author: Leo.Wang
|
||||
label:
|
||||
en_US: Gitlab Commits
|
||||
zh_Hans: Gitlab代码提交内容
|
||||
en_US: GitLab Commits
|
||||
zh_Hans: GitLab 提交内容查询
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for query gitlab commits. Input should be a exists username.
|
||||
zh_Hans: 一个用于查询gitlab代码提交记录的的工具,输入的内容应该是一个已存在的用户名或者项目名。
|
||||
llm: A tool for query gitlab commits. Input should be a exists username or project.
|
||||
en_US: A tool for query GitLab commits, Input should be a exists username or projec.
|
||||
zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。
|
||||
llm: A tool for query GitLab commits, Input should be a exists username or project.
|
||||
parameters:
|
||||
- name: employee
|
||||
- name: username
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: employee
|
||||
en_US: username
|
||||
zh_Hans: 员工用户名
|
||||
human_description:
|
||||
en_US: employee
|
||||
en_US: username
|
||||
zh_Hans: 员工用户名
|
||||
llm_description: employee for gitlab
|
||||
llm_description: User name for GitLab
|
||||
form: llm
|
||||
- name: project
|
||||
type: string
|
||||
@ -30,7 +30,7 @@ parameters:
|
||||
human_description:
|
||||
en_US: project
|
||||
zh_Hans: 项目名
|
||||
llm_description: project for gitlab
|
||||
llm_description: project for GitLab
|
||||
form: llm
|
||||
- name: start_time
|
||||
type: string
|
||||
@ -41,7 +41,7 @@ parameters:
|
||||
human_description:
|
||||
en_US: start_time
|
||||
zh_Hans: 开始时间
|
||||
llm_description: start_time for gitlab
|
||||
llm_description: Start time for GitLab
|
||||
form: llm
|
||||
- name: end_time
|
||||
type: string
|
||||
@ -52,5 +52,26 @@ parameters:
|
||||
human_description:
|
||||
en_US: end_time
|
||||
zh_Hans: 结束时间
|
||||
llm_description: end_time for gitlab
|
||||
llm_description: End time for GitLab
|
||||
form: llm
|
||||
- name: change_type
|
||||
type: select
|
||||
required: false
|
||||
options:
|
||||
- value: all
|
||||
label:
|
||||
en_US: all
|
||||
zh_Hans: 所有
|
||||
- value: new
|
||||
label:
|
||||
en_US: new
|
||||
zh_Hans: 新增
|
||||
default: all
|
||||
label:
|
||||
en_US: change_type
|
||||
zh_Hans: 变更类型
|
||||
human_description:
|
||||
en_US: change_type
|
||||
zh_Hans: 变更类型
|
||||
llm_description: Content change type for GitLab
|
||||
form: llm
|
||||
|
||||
95
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py
Normal file
95
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py
Normal file
@ -0,0 +1,95 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GitlabFilesTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
project = tool_parameters.get('project', '')
|
||||
branch = tool_parameters.get('branch', '')
|
||||
path = tool_parameters.get('path', '')
|
||||
|
||||
|
||||
if not project:
|
||||
return self.create_text_message('Project is required')
|
||||
if not branch:
|
||||
return self.create_text_message('Branch is required')
|
||||
|
||||
if not path:
|
||||
return self.create_text_message('Path is required')
|
||||
|
||||
access_token = self.runtime.credentials.get('access_tokens')
|
||||
site_url = self.runtime.credentials.get('site_url')
|
||||
|
||||
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
|
||||
return self.create_text_message("Gitlab API Access Tokens is required.")
|
||||
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
|
||||
site_url = 'https://gitlab.com'
|
||||
|
||||
# Get project ID from project name
|
||||
project_id = self.get_project_id(site_url, access_token, project)
|
||||
if not project_id:
|
||||
return self.create_text_message(f"Project '{project}' not found.")
|
||||
|
||||
# Get commit content
|
||||
result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
|
||||
|
||||
return [self.create_json_message(item) for item in result]
|
||||
|
||||
def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
|
||||
parts = path.split('/', 1)
|
||||
if len(parts) < 2:
|
||||
return None, None
|
||||
return parts[0], parts[1]
|
||||
|
||||
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
try:
|
||||
url = f"{site_url}/api/v4/projects?search={project_name}"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
projects = response.json()
|
||||
for project in projects:
|
||||
if project['name'] == project_name:
|
||||
return project['id']
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching project ID from GitLab: {e}")
|
||||
return None
|
||||
|
||||
def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
# List files and directories in the given path
|
||||
url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
|
||||
for item in items:
|
||||
item_path = item['path']
|
||||
if item['type'] == 'tree': # It's a directory
|
||||
results.extend(self.fetch(project_id, site_url, access_token, branch, item_path))
|
||||
else: # It's a file
|
||||
file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
|
||||
file_response = requests.get(file_url, headers=headers)
|
||||
file_response.raise_for_status()
|
||||
file_content = file_response.text
|
||||
results.append({
|
||||
"path": item_path,
|
||||
"branch": branch,
|
||||
"content": file_content
|
||||
})
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
return results
|
||||
@ -0,0 +1,45 @@
|
||||
identity:
|
||||
name: gitlab_files
|
||||
author: Leo.Wang
|
||||
label:
|
||||
en_US: GitLab Files
|
||||
zh_Hans: GitLab 文件获取
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for query GitLab files, Input should be branch and a exists file or directory path.
|
||||
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
|
||||
llm: A tool for query GitLab files, Input should be a exists file or directory path.
|
||||
parameters:
|
||||
- name: project
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: project
|
||||
zh_Hans: 项目
|
||||
human_description:
|
||||
en_US: project
|
||||
zh_Hans: 项目
|
||||
llm_description: Project for GitLab
|
||||
form: llm
|
||||
- name: branch
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: branch
|
||||
zh_Hans: 分支
|
||||
human_description:
|
||||
en_US: branch
|
||||
zh_Hans: 分支
|
||||
llm_description: Branch for GitLab
|
||||
form: llm
|
||||
- name: path
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: path
|
||||
zh_Hans: 文件路径
|
||||
human_description:
|
||||
en_US: path
|
||||
zh_Hans: 文件路径
|
||||
llm_description: File path for GitLab
|
||||
form: llm
|
||||
43
api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py
Normal file
43
api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py
Normal file
@ -0,0 +1,43 @@
|
||||
from typing import Any
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class JinaTokenizerTool(BuiltinTool):
|
||||
_jina_tokenizer_endpoint = 'https://tokenize.jina.ai/'
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> ToolInvokeMessage:
|
||||
content = tool_parameters['content']
|
||||
body = {
|
||||
"content": content
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'):
|
||||
headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key')
|
||||
|
||||
if tool_parameters.get('return_chunks', False):
|
||||
body['return_chunks'] = True
|
||||
|
||||
if tool_parameters.get('return_tokens', False):
|
||||
body['return_tokens'] = True
|
||||
|
||||
if tokenizer := tool_parameters.get('tokenizer'):
|
||||
body['tokenizer'] = tokenizer
|
||||
|
||||
response = ssrf_proxy.post(
|
||||
self._jina_tokenizer_endpoint,
|
||||
headers=headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
return self.create_json_message(response.json())
|
||||
@ -0,0 +1,64 @@
|
||||
identity:
|
||||
name: jina_tokenizer
|
||||
author: hjlarry
|
||||
label:
|
||||
en_US: JinaTokenizer
|
||||
description:
|
||||
human:
|
||||
en_US: Free API to tokenize text and segment long text into chunks.
|
||||
zh_Hans: 免费的API可以将文本tokenize,也可以将长文本分割成多个部分。
|
||||
llm: Free API to tokenize text and segment long text into chunks.
|
||||
parameters:
|
||||
- name: content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Content
|
||||
zh_Hans: 内容
|
||||
llm_description: the content which need to tokenize or segment
|
||||
form: llm
|
||||
- name: return_tokens
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Return the tokens
|
||||
zh_Hans: 是否返回tokens
|
||||
human_description:
|
||||
en_US: Return the tokens and their corresponding ids in the response.
|
||||
zh_Hans: 返回tokens及其对应的ids。
|
||||
form: form
|
||||
- name: return_chunks
|
||||
type: boolean
|
||||
label:
|
||||
en_US: Return the chunks
|
||||
zh_Hans: 是否分块
|
||||
human_description:
|
||||
en_US: Chunking the input into semantically meaningful segments while handling a wide variety of text types and edge cases based on common structural cues.
|
||||
zh_Hans: 将输入分块为具有语义意义的片段,同时根据常见的结构线索处理各种文本类型和边缘情况。
|
||||
form: form
|
||||
- name: tokenizer
|
||||
type: select
|
||||
options:
|
||||
- value: cl100k_base
|
||||
label:
|
||||
en_US: cl100k_base
|
||||
- value: o200k_base
|
||||
label:
|
||||
en_US: o200k_base
|
||||
- value: p50k_base
|
||||
label:
|
||||
en_US: p50k_base
|
||||
- value: r50k_base
|
||||
label:
|
||||
en_US: r50k_base
|
||||
- value: p50k_edit
|
||||
label:
|
||||
en_US: p50k_edit
|
||||
- value: gpt2
|
||||
label:
|
||||
en_US: gpt2
|
||||
label:
|
||||
en_US: Tokenizer
|
||||
human_description:
|
||||
en_US: cl100k_base - gpt-4,gpt-3.5-turbo,gpt-3.5; o200k_base - gpt-4o,gpt-4o-mini; p50k_base - text-davinci-003,text-davinci-002
|
||||
form: form
|
||||
@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class Tool(BaseModel, ABC):
|
||||
identity: Optional[ToolIdentity] = None
|
||||
@ -76,7 +78,7 @@ class Tool(BaseModel, ABC):
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
@ -84,7 +86,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
"""
|
||||
load variables from database
|
||||
@ -99,7 +101,7 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
|
||||
|
||||
self.variables.set_file(self.identity.name, variable_name, image_key)
|
||||
|
||||
def set_text_variable(self, variable_name: str, text: str) -> None:
|
||||
@ -108,9 +110,9 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
|
||||
|
||||
self.variables.set_text(self.identity.name, variable_name, text)
|
||||
|
||||
|
||||
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
|
||||
"""
|
||||
get a variable
|
||||
@ -120,14 +122,14 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return None
|
||||
|
||||
|
||||
if isinstance(name, Enum):
|
||||
name = name.value
|
||||
|
||||
|
||||
for variable in self.variables.pool:
|
||||
if variable.name == name:
|
||||
return variable
|
||||
|
||||
|
||||
return None
|
||||
|
||||
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
|
||||
@ -138,9 +140,9 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return None
|
||||
|
||||
|
||||
return self.get_variable(self.VARIABLE_KEY.IMAGE)
|
||||
|
||||
|
||||
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
|
||||
"""
|
||||
get a variable file
|
||||
@ -151,7 +153,7 @@ class Tool(BaseModel, ABC):
|
||||
variable = self.get_variable(name)
|
||||
if not variable:
|
||||
return None
|
||||
|
||||
|
||||
if not isinstance(variable, ToolRuntimeImageVariable):
|
||||
return None
|
||||
|
||||
@ -160,9 +162,9 @@ class Tool(BaseModel, ABC):
|
||||
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
|
||||
if not file_binary:
|
||||
return None
|
||||
|
||||
|
||||
return file_binary[0]
|
||||
|
||||
|
||||
def list_variables(self) -> list[ToolRuntimeVariable]:
|
||||
"""
|
||||
list all variables
|
||||
@ -171,9 +173,9 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return []
|
||||
|
||||
|
||||
return self.variables.pool
|
||||
|
||||
|
||||
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
|
||||
"""
|
||||
list all image variables
|
||||
@ -182,9 +184,9 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return []
|
||||
|
||||
|
||||
result = []
|
||||
|
||||
|
||||
for variable in self.variables.pool:
|
||||
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
|
||||
result.append(variable)
|
||||
@ -225,7 +227,7 @@ class Tool(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
pass
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials
|
||||
@ -244,7 +246,7 @@ class Tool(BaseModel, ABC):
|
||||
:return: the runtime parameters
|
||||
"""
|
||||
return self.parameters or []
|
||||
|
||||
|
||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
get all runtime parameters
|
||||
@ -278,7 +280,7 @@ class Tool(BaseModel, ABC):
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
@ -286,18 +288,18 @@ class Tool(BaseModel, ABC):
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
|
||||
message=image,
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
|
||||
message=image,
|
||||
save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
|
||||
|
||||
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
||||
message='',
|
||||
meta={
|
||||
'file_var': file_var
|
||||
},
|
||||
save_as='')
|
||||
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
@ -305,10 +307,10 @@ class Tool(BaseModel, ABC):
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=link,
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=link,
|
||||
save_as=save_as)
|
||||
|
||||
|
||||
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
@ -321,7 +323,7 @@ class Tool(BaseModel, ABC):
|
||||
message=text,
|
||||
save_as=save_as
|
||||
)
|
||||
|
||||
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.file.file_obj import FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
@ -27,12 +27,12 @@ class ToolFileMessageTransformer:
|
||||
# try to download image
|
||||
try:
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_url=message.message
|
||||
)
|
||||
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||
|
||||
result.append(ToolInvokeMessage(
|
||||
@ -55,14 +55,14 @@ class ToolFileMessageTransformer:
|
||||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode('utf-8')
|
||||
|
||||
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message,
|
||||
mimetype=mimetype
|
||||
)
|
||||
|
||||
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
@ -81,7 +81,7 @@ class ToolFileMessageTransformer:
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var: FileVar = message.meta.get('file_var')
|
||||
file_var = message.meta.get('file_var')
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
@ -103,7 +103,7 @@ class ToolFileMessageTransformer:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||
|
||||
@ -4,13 +4,14 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""
|
||||
Node Types.
|
||||
"""
|
||||
|
||||
START = 'start'
|
||||
END = 'end'
|
||||
ANSWER = 'answer'
|
||||
@ -44,33 +45,11 @@ class NodeType(Enum):
|
||||
raise ValueError(f'invalid node type value {value}')
|
||||
|
||||
|
||||
class SystemVariable(Enum):
|
||||
"""
|
||||
System Variables.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
FILES = 'files'
|
||||
CONVERSATION_ID = 'conversation_id'
|
||||
USER_ID = 'user_id'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'SystemVariable':
|
||||
"""
|
||||
Get value of given system variable.
|
||||
|
||||
:param value: system variable value
|
||||
:return: system variable
|
||||
"""
|
||||
for system_variable in cls:
|
||||
if system_variable.value == value:
|
||||
return system_variable
|
||||
raise ValueError(f'invalid system variable value {value}')
|
||||
|
||||
|
||||
class NodeRunMetadataKey(Enum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = 'total_tokens'
|
||||
TOTAL_PRICE = 'total_price'
|
||||
CURRENCY = 'currency'
|
||||
@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
|
||||
"""
|
||||
Node Run Result.
|
||||
"""
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
|
||||
@ -6,7 +6,7 @@ from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.enums import SystemVariable
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user