Compare commits

..

2 Commits

Author SHA1 Message Date
3bfb26d571 fix model runtime quota 2024-08-12 17:09:09 +08:00
ccf4bd8555 fix model runtime quota 2024-08-12 15:57:53 +08:00
437 changed files with 3959 additions and 12847 deletions

View File

@ -76,7 +76,7 @@ jobs:
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
@ -90,6 +90,5 @@ jobs:
pgvecto-rs
pgvector
chroma
elasticsearch
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@ -6,6 +6,5 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch"
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs."

View File

@ -45,10 +45,6 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Ruff formatter check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff format --check ./api
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

View File

@ -152,7 +152,7 @@ Nhanh chóng chạy Dify trong môi trường của bạn với [hướng dẫn
Sử dụng [tài liệu](https://docs.dify.ai) của chúng tôi để tham khảo thêm và nhận hướng dẫn chi tiết hơn.
- **Dify cho doanh nghiệp / tổ chức</br>**
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Ghi lại câu hỏi của bạn cho chúng tôi thông qua chatbot này](https://udify.app/chat/22L1zSxg6yW1cWQg) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. </br>
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Lên lịch cuộc họp với chúng tôi](https://cal.com/guchenhe/30min) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. </br>
> Đối với các công ty khởi nghiệp và doanh nghiệp nhỏ sử dụng AWS, hãy xem [Dify Premium trên AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) và triển khai nó vào AWS VPC của riêng bạn chỉ với một cú nhấp chuột. Đây là một AMI giá cả phải chăng với tùy chọn tạo ứng dụng với logo và thương hiệu tùy chỉnh.
@ -221,6 +221,23 @@ Triển khai Dify lên Azure chỉ với một cú nhấp chuột bằng cách s
* [Discord](https://discord.gg/FngNHpbcY7). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng.
* [Twitter](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng.
Hoặc, lên lịch cuộc họp trực tiếp với một thành viên trong nhóm:
<table>
<tr>
<th>Điểm liên hệ</th>
<th>Mục đích</th>
</tr>
<tr>
<td><a href='https://cal.com/guchenhe/15min' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/9ebcd111-1205-4d71-83d5-948d70b809f5' alt='Git-Hub-README-Button-3x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>Yêu cầu kinh doanh & phản hồi sản phẩm</td>
</tr>
<tr>
<td><a href='https://cal.com/pinkbanana' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/d1edd00a-d7e4-4513-be6c-e57038e143fd' alt='Git-Hub-README-Button-2x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>Đóng góp, vấn đề & yêu cầu tính năng</td>
</tr>
</table>
## Lịch sử Yêu thích
[![Biểu đồ Lịch sử Yêu thích](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)

View File

@ -130,12 +130,6 @@ TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
# ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1
ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431

View File

@ -1,6 +1,6 @@
import os
if os.environ.get("DEBUG", "false").lower() != "true":
if os.environ.get("DEBUG", "false").lower() != 'true':
from gevent import monkey
monkey.patch_all()
@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
if os.name == "nt":
os.system('tzutil /s "UTC"')
else:
os.environ["TZ"] = "UTC"
os.environ['TZ'] = 'UTC'
time.tzset()
@ -70,14 +70,13 @@ class DifyApp(Flask):
# -------------
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> Flask:
"""
create a raw flask app
@ -93,7 +92,7 @@ def create_flask_app_with_configs() -> Flask:
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ""
os.environ[key] = ''
return dify_app
@ -101,10 +100,10 @@ def create_flask_app_with_configs() -> Flask:
def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config["SECRET_KEY"]
app.secret_key = app.config['SECRET_KEY']
log_handlers = None
log_file = app.config.get("LOG_FILE")
log_file = app.config.get('LOG_FILE')
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
@ -112,24 +111,23 @@ def create_app() -> Flask:
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5,
backupCount=5
),
logging.StreamHandler(sys.stdout),
logging.StreamHandler(sys.stdout)
]
logging.basicConfig(
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
level=app.config.get('LOG_LEVEL'),
format=app.config.get('LOG_FORMAT'),
datefmt=app.config.get('LOG_DATEFORMAT'),
handlers=log_handlers,
force=True,
force=True
)
log_tz = app.config.get("LOG_TZ")
log_tz = app.config.get('LOG_TZ')
if log_tz:
from datetime import datetime
import pytz
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
@ -164,24 +162,24 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in ["console", "inner_api"]:
if request.blueprint not in ['console', 'inner_api']:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")
auth_header = request.headers.get('Authorization', '')
if not auth_header:
auth_token = request.args.get("_token")
auth_token = request.args.get('_token')
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
raise Unauthorized('Invalid Authorization token.')
else:
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")
user_id = decoded.get('user_id')
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
if account:
@ -192,11 +190,10 @@ def load_user_from_request(request_from_flask_login):
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
content_type="application/json",
)
return Response(json.dumps({
'code': 'unauthorized',
'message': "Unauthorized."
}), status=401, content_type="application/json")
# register blueprint routers
@ -207,36 +204,38 @@ def register_blueprints(app):
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
app.register_blueprint(service_api_bp)
CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
CORS(web_bp,
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
app.register_blueprint(web_bp)
CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
CORS(console_app_bp,
resources={
r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
app.register_blueprint(console_app_bp)
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
CORS(files_bp,
allow_headers=['Content-Type'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)
@ -246,29 +245,29 @@ def register_blueprints(app):
app = create_app()
celery = app.extensions["celery"]
if app.config.get("TESTING"):
if app.config.get('TESTING'):
print("App is running in TESTING mode")
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie("remember_token", "", expires=0)
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
return response
@app.route("/health")
@app.route('/health')
def health():
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
status=200,
content_type="application/json",
)
return Response(json.dumps({
'pid': os.getpid(),
'status': 'ok',
'version': app.config['CURRENT_VERSION']
}), status=200, content_type="application/json")
@app.route("/threads")
@app.route('/threads')
def threads():
num_threads = threading.active_count()
threads = threading.enumerate()
@ -279,34 +278,32 @@ def threads():
thread_id = thread.ident
is_alive = thread.is_alive()
thread_list.append(
{
"name": thread_name,
"id": thread_id,
"is_alive": is_alive,
}
)
thread_list.append({
'name': thread_name,
'id': thread_id,
'is_alive': is_alive
})
return {
"pid": os.getpid(),
"thread_num": num_threads,
"threads": thread_list,
'pid': os.getpid(),
'thread_num': num_threads,
'threads': thread_list
}
@app.route("/db-pool-stat")
@app.route('/db-pool-stat')
def pool_stat():
engine = db.engine
return {
"pid": os.getpid(),
"pool_size": engine.pool.size(),
"checked_in_connections": engine.pool.checkedin(),
"checked_out_connections": engine.pool.checkedout(),
"overflow_connections": engine.pool.overflow(),
"connection_timeout": engine.pool.timeout(),
"recycle_time": db.engine.pool._recycle,
'pid': os.getpid(),
'pool_size': engine.pool.size(),
'checked_in_connections': engine.pool.checkedin(),
'checked_out_connections': engine.pool.checkedout(),
'overflow_connections': engine.pool.overflow(),
'connection_timeout': engine.pool.timeout(),
'recycle_time': db.engine.pool._recycle
}
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5001)

View File

@ -27,29 +27,32 @@ from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
@click.option("--new-password", prompt=True, help="the new password.")
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
@click.command('reset-password', help='Reset the account password.')
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
@click.option('--new-password', prompt=True, help='the new password.')
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
def reset_password(email, new_password, password_confirm):
"""
Reset password of owner account
Only available in SELF_HOSTED mode
"""
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
if not account:
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
return
try:
valid_password(new_password)
except:
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
click.echo(
click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
return
# generate password salt
@ -62,87 +65,80 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
@click.command("reset-email", help="Reset the account email.")
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
@click.option("--new-email", prompt=True, help="the new email.")
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
@click.command('reset-email', help='Reset the account email.')
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
@click.option('--new-email', prompt=True, help='the new email.')
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
def reset_email(email, new_email, email_confirm):
"""
Replace account email
:return:
"""
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
if not account:
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
return
try:
email_validate(new_email)
except:
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
click.echo(
click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
return
account.email = new_email
db.session.commit()
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
@click.command(
"reset-encrypt-key-pair",
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
"After the reset, all LLM credentials will become invalid, "
"requiring re-entry."
"Only support SELF_HOSTED mode.",
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
)
)
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
'After the reset, all LLM credentials will become invalid, '
'requiring re-entry.'
'Only support SELF_HOSTED mode.')
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
' this operation cannot be rolled back!', fg='red'))
def reset_encrypt_key_pair():
"""
Reset the encrypted key pair of workspace for encrypt LLM credentials.
After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode.
"""
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
if dify_config.EDITION != 'SELF_HOSTED':
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
return
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(
click.style(
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
click.echo(click.style('Congratulations! '
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
@click.command('vdb-migrate', help='migrate vector db.')
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
def vdb_migrate(scope: str):
if scope in ["knowledge", "all"]:
if scope in ['knowledge', 'all']:
migrate_knowledge_vector_database()
if scope in ["annotation", "all"]:
if scope in ['annotation', 'all']:
migrate_annotation_vector_database()
@ -150,7 +146,7 @@ def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style("Start migrate annotation data.", fg="green"))
click.echo(click.style('Start migrate annotation data.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
@ -158,103 +154,98 @@ def migrate_annotation_vector_database():
while True:
try:
# get apps info
apps = (
db.session.query(App)
.filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
)
apps = db.session.query(App).filter(
App.status == 'normal'
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for app in apps:
total_count = total_count + 1
click.echo(
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
)
click.echo(f'Processing the {total_count} app {app.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
try:
click.echo("Create app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
)
click.echo('Create app annotation index: {}'.format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app.id
).first()
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo("App annotation setting is disabled: {}".format(app.id))
click.echo('App annotation setting is disabled: {}'.format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
).first()
if not dataset_collection_binding:
click.echo("App annotation collection binding is not exist: {}".format(app.id))
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique="high_quality",
indexing_technique='high_quality',
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
collection_binding_id=dataset_collection_binding.id
)
documents = []
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
metadata={
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
)
documents.append(document)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
click.echo(
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
except Exception as e:
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
click.echo(
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
raise e
if documents:
try:
click.echo(
click.style(
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
)
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
raise e
click.echo(f"Successfully migrated app annotation {app.id}.")
click.echo(f'Successfully migrated app annotation {app.id}.')
create_count += 1
except Exception as e:
click.echo(
click.style(
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
)
)
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(
click.style(
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
fg="green",
)
)
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
fg='green'))
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style("Start migrate vector db.", fg="green"))
click.echo(click.style('Start migrate vector db.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
@ -262,77 +253,87 @@ def migrate_knowledge_vector_database():
page = 1
while True:
try:
datasets = (
db.session.query(Dataset)
.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
total_count = total_count + 1
click.echo(
f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
)
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
try:
click.echo("Create dataset vdb index: {}".format(dataset.id))
click.echo('Create dataset vdb index: {}'.format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict["type"] == vector_type:
if dataset.index_struct_dict['type'] == vector_type:
skipped_count = skipped_count + 1
continue
collection_name = ""
collection_name = ''
if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
index_struct_dict = {
"type": VectorType.WEAVIATE,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError("Dataset Collection Bindings is not exist!")
raise ValueError('Dataset Collection Bindings is not exist!')
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
index_struct_dict = {
"type": VectorType.QDRANT,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.MILVUS:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
index_struct_dict = {
"type": VectorType.MILVUS,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.RELYT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
index_struct_dict = {
"type": 'relyt',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
index_struct_dict = {
"type": VectorType.TENCENT,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
index_struct_dict = {
"type": VectorType.PGVECTOR,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.OPENSEARCH:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.OPENSEARCH,
"vector_store": {"class_prefix": collection_name},
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
@ -340,14 +341,9 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name},
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ELASTICSEARCH:
dataset_id = dataset.id
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")
@ -357,41 +353,29 @@ def migrate_knowledge_vector_database():
try:
vector.delete()
click.echo(
click.style(
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
)
)
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
fg='green'))
except Exception as e:
click.echo(
click.style(
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
)
)
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
fg='red'))
raise e
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
@ -401,7 +385,7 @@ def migrate_knowledge_vector_database():
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
}
)
documents.append(document)
@ -409,43 +393,37 @@ def migrate_knowledge_vector_database():
if documents:
try:
click.echo(
click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
fg="green",
)
)
click.echo(click.style(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
)
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
raise e
db.session.add(dataset)
db.session.commit()
click.echo(f"Successfully migrated dataset {dataset.id}.")
click.echo(f'Successfully migrated dataset {dataset.id}.')
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(
click.style(
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
)
)
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
fg='green'))
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
def convert_to_agent_apps():
"""
Convert Agent Assistant to Agent App.
"""
click.echo(click.style("Start convert to agent apps.", fg="green"))
click.echo(click.style('Start convert to agent apps.', fg='green'))
proceeded_app_ids = []
@ -480,7 +458,7 @@ def convert_to_agent_apps():
break
for app in apps:
click.echo("Converting app: {}".format(app.id))
click.echo('Converting app: {}'.format(app.id))
try:
app.mode = AppMode.AGENT_CHAT.value
@ -492,139 +470,137 @@ def convert_to_agent_apps():
)
db.session.commit()
click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
except Exception as e:
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
click.echo(
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
str(e)), fg='red'))
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
def add_qdrant_doc_id_index(field: str):
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
vector_type = dify_config.VECTOR_STORE
if vector_type != "qdrant":
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
return
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
return
import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant url is required.")
raise ValueError('Qdrant url is required.')
qdrant_config = QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
)
try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
# create payload index
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
client.create_payload_index(binding.collection_name, field,
field_schema=PayloadSchemaType.KEYWORD)
create_count += 1
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
click.echo(
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
)
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
continue
# Some other error occurred, so re-raise the exception
else:
click.echo(
click.style(
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
)
)
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
except Exception as e:
click.echo(click.style("Failed to create qdrant client.", fg="red"))
click.echo(click.style('Failed to create qdrant client.', fg='red'))
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
click.echo(
click.style(f'Congratulations! Create {create_count} collection indexes.',
fg='green'))
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
@click.command('create-tenant', help='Create account and tenant.')
@click.option('--email', prompt=True, help='The email address of the tenant account.')
@click.option('--language', prompt=True, help='Account language, default: en-US.')
def create_tenant(email: str, language: Optional[str] = None):
"""
Create tenant account
"""
if not email:
click.echo(click.style("Sorry, email is required.", fg="red"))
click.echo(click.style('Sorry, email is required.', fg='red'))
return
# Create account
email = email.strip()
if "@" not in email:
click.echo(click.style("Sorry, invalid email address.", fg="red"))
if '@' not in email:
click.echo(click.style('Sorry, invalid email address.', fg='red'))
return
account_name = email.split("@")[0]
account_name = email.split('@')[0]
if language not in languages:
language = "en-US"
language = 'en-US'
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
account = RegisterService.register(
email=email,
name=account_name,
password=new_password,
language=language
)
TenantService.create_owner_tenant_if_not_exist(account)
click.echo(
click.style(
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
click.echo(click.style('Congratulations! Account and tenant created.\n'
'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
@click.command("upgrade-db", help="upgrade the database")
@click.command('upgrade-db', help='upgrade the database')
def upgrade_db():
click.echo("Preparing database migration...")
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
click.echo('Preparing database migration...')
lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
if lock.acquire(blocking=False):
try:
click.echo(click.style("Start database migration.", fg="green"))
click.echo(click.style('Start database migration.', fg='green'))
# run db migration
import flask_migrate
flask_migrate.upgrade()
click.echo(click.style("Database migration successful!", fg="green"))
click.echo(click.style('Database migration successful!', fg='green'))
except Exception as e:
logging.exception(f"Database migration failed, error: {e}")
logging.exception(f'Database migration failed, error: {e}')
finally:
lock.release()
else:
click.echo("Database migration skipped")
click.echo('Database migration skipped')
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
def fix_app_site_missing():
"""
Fix app related site missing issue.
"""
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
click.echo(click.style('Start fix app related site missing issue.', fg='green'))
failed_app_ids = []
while True:
@ -655,14 +631,15 @@ where sites.id is null limit 1000"""
app_was_created.send(app, account=account)
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}")
click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
logging.exception(f'Fix app related site missing issue failed, error: {e}')
continue
if not processed_count:
break
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
def register_commands(app):

View File

@ -12,14 +12,19 @@ from configs.packaging import PackagingInfo
class DifyConfig(
# Packaging info
PackagingInfo,
# Deployment configs
DeploymentConfig,
# Feature configs
FeatureConfig,
# Middleware configs
MiddlewareConfig,
# Extra service configs
ExtraServiceConfig,
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
@ -31,6 +36,7 @@ class DifyConfig(
env_file='.env',
env_file_encoding='utf-8',
frozen=True,
# ignore extra attributes
extra='ignore',
)
@ -61,5 +67,3 @@ class DifyConfig(
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.7.0',
default='0.6.16',
)
COMMIT_SHA: str = Field(

View File

@ -1 +1 @@
HIDDEN_VALUE = "[__HIDDEN__]"
HIDDEN_VALUE = '[__HIDDEN__]'

View File

@ -1,22 +1,22 @@
language_timezone_mapping = {
"en-US": "America/New_York",
"zh-Hans": "Asia/Shanghai",
"zh-Hant": "Asia/Taipei",
"pt-BR": "America/Sao_Paulo",
"es-ES": "Europe/Madrid",
"fr-FR": "Europe/Paris",
"de-DE": "Europe/Berlin",
"ja-JP": "Asia/Tokyo",
"ko-KR": "Asia/Seoul",
"ru-RU": "Europe/Moscow",
"it-IT": "Europe/Rome",
"uk-UA": "Europe/Kyiv",
"vi-VN": "Asia/Ho_Chi_Minh",
"ro-RO": "Europe/Bucharest",
"pl-PL": "Europe/Warsaw",
"hi-IN": "Asia/Kolkata",
"tr-TR": "Europe/Istanbul",
"fa-IR": "Asia/Tehran",
'en-US': 'America/New_York',
'zh-Hans': 'Asia/Shanghai',
'zh-Hant': 'Asia/Taipei',
'pt-BR': 'America/Sao_Paulo',
'es-ES': 'Europe/Madrid',
'fr-FR': 'Europe/Paris',
'de-DE': 'Europe/Berlin',
'ja-JP': 'Asia/Tokyo',
'ko-KR': 'Asia/Seoul',
'ru-RU': 'Europe/Moscow',
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
'vi-VN': 'Asia/Ho_Chi_Minh',
'ro-RO': 'Europe/Bucharest',
'pl-PL': 'Europe/Warsaw',
'hi-IN': 'Asia/Kolkata',
'tr-TR': 'Europe/Istanbul',
'fa-IR': 'Asia/Tehran',
}
languages = list(language_timezone_mapping.keys())
@ -26,5 +26,6 @@ def supported_language(lang):
if lang in languages:
return lang
error = "{lang} is not a valid language.".format(lang=lang)
error = ('{lang} is not a valid language.'
.format(lang=lang))
raise ValueError(error)

View File

@ -5,79 +5,82 @@ from models.model import AppMode
default_app_templates = {
# workflow default mode
AppMode.WORKFLOW: {
"app": {
"mode": AppMode.WORKFLOW.value,
"enable_site": True,
"enable_api": True,
'app': {
'mode': AppMode.WORKFLOW.value,
'enable_site': True,
'enable_api': True
}
},
# completion default mode
AppMode.COMPLETION: {
"app": {
"mode": AppMode.COMPLETION.value,
"enable_site": True,
"enable_api": True,
'app': {
'mode': AppMode.COMPLETION.value,
'enable_site': True,
'enable_api': True
},
"model_config": {
"model": {
'model_config': {
'model': {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {},
"completion_params": {}
},
"user_input_form": json.dumps(
[
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": "",
},
},
]
),
"pre_prompt": "{{query}}",
'user_input_form': json.dumps([
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
]),
'pre_prompt': '{{query}}'
},
},
# chat default mode
AppMode.CHAT: {
"app": {
"mode": AppMode.CHAT.value,
"enable_site": True,
"enable_api": True,
'app': {
'mode': AppMode.CHAT.value,
'enable_site': True,
'enable_api': True
},
"model_config": {
"model": {
'model_config': {
'model': {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {},
},
},
"completion_params": {}
}
}
},
# advanced-chat default mode
AppMode.ADVANCED_CHAT: {
"app": {
"mode": AppMode.ADVANCED_CHAT.value,
"enable_site": True,
"enable_api": True,
},
'app': {
'mode': AppMode.ADVANCED_CHAT.value,
'enable_site': True,
'enable_api': True
}
},
# agent-chat default mode
AppMode.AGENT_CHAT: {
"app": {
"mode": AppMode.AGENT_CHAT.value,
"enable_site": True,
"enable_api": True,
'app': {
'mode': AppMode.AGENT_CHAT.value,
'enable_site': True,
'enable_api': True
},
"model_config": {
"model": {
'model_config': {
'model': {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {},
},
},
},
"completion_params": {}
}
}
}
}

View File

@ -1,7 +1,3 @@
from contextvars import ContextVar
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
tenant_id: ContextVar[str] = ContextVar('tenant_id')

View File

@ -17,7 +17,6 @@ from .app import (
audio,
completion,
conversation,
conversation_variables,
generator,
message,
model_config,

View File

@ -61,7 +61,6 @@ class AppListApi(Resource):
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
@ -95,7 +94,6 @@ class AppImportApi(Resource):
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
@ -169,7 +167,6 @@ class AppApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
@ -211,7 +208,6 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()

View File

@ -33,7 +33,7 @@ class CompletionConversationApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
def get(self, app_model):
if not current_user.is_editor:
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
def get(self, app_model, conversation_id):
if not current_user.is_editor:
if not current_user.is_admin_or_owner:
raise Forbidden()
conversation_id = str(conversation_id)
@ -119,7 +119,7 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
if not current_user.is_admin_or_owner:
raise Forbidden()
conversation_id = str(conversation_id)
@ -256,7 +256,7 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
if not current_user.is_admin_or_owner:
raise Forbidden()
conversation_id = str(conversation_id)

View File

@ -1,61 +0,0 @@
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.conversation_variable_fields import paginated_conversation_variable_fields
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
class ConversationVariablesApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', type=str, location='args')
args = parser.parse_args()
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args['conversation_id']:
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
else:
raise ValueError('conversation_id is required')
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
page_size = 100
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
with Session(db.engine) as session:
rows = session.scalars(stmt).all()
return {
'page': page,
'limit': page_size,
'total': len(rows),
'has_more': False,
'data': [
{
'created_at': row.created_at,
'updated_at': row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')

View File

@ -16,7 +16,6 @@ from models.model import Site
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument('title', type=str, required=False, location='json')
parser.add_argument('icon_type', type=str, required=False, location='json')
parser.add_argument('icon', type=str, required=False, location='json')
parser.add_argument('icon_background', type=str, required=False, location='json')
parser.add_argument('description', type=str, required=False, location='json')
@ -54,7 +53,6 @@ class AppSite(Resource):
for attr_name in [
'title',
'icon_type',
'icon',
'icon_background',
'description',

View File

@ -74,7 +74,6 @@ class DraftWorkflowApi(Resource):
parser.add_argument('hash', type=str, required=False, location='json')
# TODO: set this to required=True after frontend is updated
parser.add_argument('environment_variables', type=list, required=False, location='json')
parser.add_argument('conversation_variables', type=list, required=False, location='json')
args = parser.parse_args()
elif 'text/plain' in content_type:
try:
@ -89,8 +88,7 @@ class DraftWorkflowApi(Resource):
'graph': data.get('graph'),
'features': data.get('features'),
'hash': data.get('hash'),
'environment_variables': data.get('environment_variables'),
'conversation_variables': data.get('conversation_variables'),
'environment_variables': data.get('environment_variables')
}
except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400
@ -102,8 +100,6 @@ class DraftWorkflowApi(Resource):
try:
environment_variables_list = args.get('environment_variables') or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = args.get('conversation_variables') or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args['graph'],
@ -111,7 +107,6 @@ class DraftWorkflowApi(Resource):
unique_hash=args.get('hash'),
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@ -459,7 +454,6 @@ class ConvertToWorkflowApi(Resource):
if request.data:
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()

View File

@ -555,7 +555,7 @@ class DatasetRetrievalSettingApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
@ -579,7 +579,7 @@ class DatasetRetrievalSettingMockApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,

View File

@ -178,20 +178,11 @@ class DatasetDocumentListApi(Resource):
.subquery()
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
.order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
.order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)))
elif sort == 'created_at':
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
)
query = query.order_by(sort_logic(Document.created_at))
else:
query = query.order_by(
desc(Document.created_at),
desc(Document.position),
)
query = query.order_by(desc(Document.created_at))
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)

View File

@ -131,7 +131,7 @@ class MessageSuggestedApi(Resource):
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
except SuggestedQuestionsAfterAnswerDisabledError:
raise BadRequest("Suggested Questions Is Disabled.")
raise BadRequest("Message Not Exists.")
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -53,22 +53,19 @@ class SegmentApi(DatasetApiResource):
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
if args['segments'] is not None:
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
else:
return {"error": "Segemtns is required"}, 400
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""

View File

@ -6,7 +6,6 @@ from configs import dify_config
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from libs.helper import AppIconUrlField
from models.account import TenantStatus
from models.model import Site
from services.feature_service import FeatureService
@ -29,10 +28,8 @@ class AppSiteApi(WebApiResource):
'title': fields.String,
'chat_color_theme': fields.String,
'chat_color_theme_inverted': fields.Boolean,
'icon_type': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'icon_url': AppIconUrlField,
'description': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,

View File

@ -64,19 +64,15 @@ class BaseAgentRunner(AppRunner):
"""
Agent runner
:param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@ -449,7 +445,7 @@ class BaseAgentRunner(AppRunner):
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
tool_responses = { tool: agent_thought.observation for tool in tools }
for tool in tools:
# generate a uuid for tool call

View File

@ -292,8 +292,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
handle invoke action
:param action: action
:param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta
"""
# action is tool call, invoke tool

View File

@ -93,7 +93,6 @@ class DatasetConfigManager:
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs["reranking_mode"],
)
)

View File

@ -3,9 +3,8 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.file.file_obj import FileExtraConfig
from core.model_runtime.entities.message_entities import PromptMessageRole
from models import AppMode
from models.model import AppMode
class ModelConfigEntity(BaseModel):
@ -201,6 +200,11 @@ class TracingConfigEntity(BaseModel):
tracing_provider: str
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class AppAdditionalFeatures(BaseModel):

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.file.file_obj import FileExtraConfig
from core.app.app_config.entities import FileExtraConfig
class FileUploadConfigManager:

View File

@ -8,8 +8,6 @@ from typing import Union
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -120,6 +113,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
@ -127,7 +121,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
stream=stream
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
@ -147,10 +141,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
@ -186,6 +180,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
@ -194,12 +189,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream
)
def _generate(self, *,
def _generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
conversation: Conversation = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False
@ -216,7 +211,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
# db.session.refresh(conversation)
db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -228,69 +223,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
'user': user,
'context': contextvars.copy_context()
})
worker_thread.start()
@ -303,7 +244,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=stream
)
return AdvancedChatAppGenerateResponseConverter.convert(
@ -314,7 +255,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user: Account,
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
@ -341,7 +284,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_id=application_generate_entity.user_id
)
else:
# get message
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
@ -349,6 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except GenerateTaskStoppedException:
@ -371,17 +316,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()
def _handle_advanced_chat_response(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False) \
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@ -401,7 +343,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=stream
)
try:

View File

@ -16,10 +16,12 @@ from core.app.entities.app_invoke_entities import (
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models import App, Message, Workflow
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -29,12 +31,10 @@ class AdvancedChatAppRunner(AppRunner):
AdvancedChat Application Runner
"""
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message,
) -> None:
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@ -48,43 +48,53 @@ class AdvancedChatAppRunner(AppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity
):
return
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
@ -96,29 +106,43 @@ class AdvancedChatAppRunner(AppRunner):
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
call_depth=application_generate_entity.call_depth
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@ -126,25 +150,22 @@ class AdvancedChatAppRunner(AppRunner):
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
# return workflow
return workflow
def handle_input_moderation(
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:
"""
Handle input moderation
@ -171,20 +192,17 @@ class AdvancedChatAppRunner(AppRunner):
queue_manager=queue_manager,
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True
return False
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
"""
Handle annotation reply
:param app_record: app record
@ -199,27 +217,29 @@ class AdvancedChatAppRunner(AppRunner):
message=message,
query=query,
user_id=app_generate_entity.user_id,
invoke_from=app_generate_entity.invoke_from,
invoke_from=app_generate_entity.invoke_from
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
)
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
def _stream_output(self, queue_manager: AppQueueManager,
text: str,
stream: bool,
stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param queue_manager: application queue manager
@ -230,10 +250,21 @@ class AdvancedChatAppRunner(AppRunner):
if stream:
index = 0
for token in text:
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by),
PublishFrom.APPLICATION_MANAGER
)

View File

@ -4,7 +4,6 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -48,8 +47,7 @@ from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@ -73,7 +71,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
@ -84,7 +81,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
stream: bool
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -106,12 +103,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
SystemVariable.USER_ID: user_id
}
self._task_state = AdvancedChatTaskState(
@ -249,7 +245,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
for message in self._queue_manager.listen():
if (message.event
and getattr(message.event, 'metadata', None)
and hasattr(message.event, 'metadata')
and message.event.metadata
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
@ -616,9 +613,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if route_chunk_node_id == 'sys':
# system variable
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union
from typing import Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -14,6 +14,7 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -26,16 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
files: list[FileVar],
query: Optional[str] = None) -> int:
"""
Get pre calculate rest tokens
@ -128,7 +126,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
files: list[FileVar],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
@ -256,7 +254,6 @@ class AppRunner:
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
:return:
"""
if not stream:
@ -279,7 +276,6 @@ class AppRunner:
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
queue_manager.publish(
@ -295,7 +291,6 @@ class AppRunner:
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
model = None
@ -371,7 +366,7 @@ class AppRunner:
message_id=message_id,
trace_manager=app_generate_entity.trace_manager
)
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
@ -423,7 +418,7 @@ class AppRunner:
inputs=inputs,
query=query
)
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,

View File

@ -138,7 +138,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
"""
Initialize generate records
:param application_generate_entity: application generate entity
:conversation conversation
:return:
"""
app_config = application_generate_entity.app_config
@ -259,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction
def _get_conversation(self, conversation_id: str):
def _get_conversation(self, conversation_id: str) -> Conversation:
"""
Get conversation by conversation id
:param conversation_id: conversation id
@ -271,9 +270,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
.first()
)
if not conversation:
raise ConversationNotExistsError()
return conversation
def _get_message(self, message_id: str) -> Message:

View File

@ -11,8 +11,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@ -27,7 +26,8 @@ class WorkflowAppRunner:
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@ -47,36 +47,25 @@ class WorkflowAppRunner:
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@ -86,33 +75,44 @@ class WorkflowAppRunner:
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
call_depth=application_generate_entity.call_depth
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
if not app_record.workflow_id:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@ -120,13 +120,11 @@ class WorkflowAppRunner:
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
# return workflow
return workflow

View File

@ -42,8 +42,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@ -520,7 +519,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
@ -531,3 +530,4 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@ -1,7 +1,7 @@
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -12,9 +12,11 @@ from .segments import (
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
@ -29,6 +31,7 @@ __all__ = [
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'FileVariable',
'StringVariable',
'ArrayAnyVariable',
'Variable',
@ -41,9 +44,10 @@ __all__ = [
'FloatSegment',
'ObjectSegment',
'ArrayAnySegment',
'FileSegment',
'StringSegment',
'ArrayStringVariable',
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArraySegment',
'ArrayFileVariable',
]

View File

@ -1,2 +0,0 @@
class VariableError(Exception):
pass

View File

@ -1,11 +1,11 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.file.file_obj import FileVar
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -15,9 +15,11 @@ from .segments import (
)
from .types import SegmentType
from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
@ -27,37 +29,39 @@ from .variables import (
)
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get('value_type')) is None:
raise VariableError('missing value type')
if not mapping.get('name'):
raise VariableError('missing name')
if (value := mapping.get('value')) is None:
raise VariableError('missing value')
def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
if (value_type := m.get('value_type')) is None:
raise ValueError('missing value type')
if not m.get('name'):
raise ValueError('missing name')
if (value := m.get('value')) is None:
raise ValueError('missing value')
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
return StringVariable.model_validate(m)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
return SecretVariable.model_validate(m)
case SegmentType.NUMBER if isinstance(value, int):
result = IntegerVariable.model_validate(mapping)
return IntegerVariable.model_validate(m)
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
return FloatVariable.model_validate(m)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}')
raise ValueError(f'invalid number value {value}')
case SegmentType.FILE:
return FileVariable.model_validate(m)
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
return ObjectVariable.model_validate(
{**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}}
)
case SegmentType.ARRAY_STRING if isinstance(value, list):
result = ArrayStringVariable.model_validate(mapping)
return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
result = ArrayNumberVariable.model_validate(mapping)
return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
return result
return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
case SegmentType.ARRAY_FILE if isinstance(value, list):
return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
raise ValueError(f'not supported value type {value_type}')
def build_segment(value: Any, /) -> Segment:
@ -70,7 +74,12 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
# TODO: Limit the depth of the object
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
# TODO: Limit the depth of the array
elements = [build_segment(v) for v in value]
return ArrayAnySegment(value=elements)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')

View File

@ -1,10 +1,11 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from core.file.file_obj import FileVar
from .types import SegmentType
@ -36,10 +37,6 @@ class Segment(BaseModel):
def markdown(self) -> str:
return str(self.value)
@property
def size(self) -> int:
return sys.getsizeof(self.value)
def to_object(self) -> Any:
return self.value
@ -76,7 +73,14 @@ class IntegerSegment(Segment):
value: int
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
# TODO: embed FileVar in this model.
value: FileVar
@property
def markdown(self) -> str:
return self.value.to_markdown()
class ObjectSegment(Segment):
@ -99,31 +103,32 @@ class ObjectSegment(Segment):
class ArraySegment(Segment):
@property
def markdown(self) -> str:
items = []
for item in self.value:
if hasattr(item, 'to_markdown'):
items.append(item.to_markdown())
else:
items.append(str(item))
return '\n'.join(items)
return '\n'.join(['- ' + item.markdown for item in self.value])
def to_object(self):
return [v.to_object() for v in self.value]
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any]
value: Sequence[Segment]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str]
value: Sequence[StringSegment]
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int]
value: Sequence[FloatSegment | IntegerSegment]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]
value: Sequence[ObjectSegment]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[FileSegment]

View File

@ -10,6 +10,8 @@ class SegmentType(str, Enum):
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
ARRAY_FILE = 'array[file]'
OBJECT = 'object'
FILE = 'file'
GROUP = 'group'

View File

@ -4,9 +4,11 @@ from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -42,6 +44,10 @@ class IntegerVariable(IntegerSegment, Variable):
pass
class FileVariable(FileSegment, Variable):
pass
class ObjectVariable(ObjectSegment, Variable):
pass
@ -62,6 +68,9 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass
class ArrayFileVariable(ArrayFileSegment, Variable):
pass
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import SystemVariable
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariable, Any]

View File

@ -1,19 +1,14 @@
import enum
from typing import Any, Optional
from typing import Optional
from pydantic import BaseModel
from core.app.app_config.entities import FileExtraConfig
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
from models.model import UploadFile
class FileType(enum.Enum):
@ -119,7 +114,6 @@ class FileVar(BaseModel):
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url

View File

@ -5,7 +5,8 @@ from urllib.parse import parse_qs, urlparse
import requests
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from core.app.app_config.entities import FileExtraConfig
from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile
@ -99,7 +100,7 @@ class MessageFileParser:
# return all file objs
return new_files
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
"""
transform message files
@ -144,7 +145,7 @@ class MessageFileParser:
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
"""
transform file to file obj

View File

@ -2,6 +2,7 @@ import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
@ -13,7 +14,6 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f'Tenant with id {tenant_id} not found')
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)

View File

@ -271,8 +271,9 @@ class ModelInstance:
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param voice: model timbre
:param user: unique user id
:param voice: model timbre
:param streaming: output is streaming
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
@ -400,10 +401,6 @@ class LBModelManager:
managed_credentials: Optional[dict] = None) -> None:
"""
Load balancing model manager
:param tenant_id: tenant_id
:param provider: provider
:param model_type: model_type
:param model: model name
:param load_balancing_configs: all load balancing configurations
:param managed_credentials: credentials if load balancing configuration name is __inherit__
"""

View File

@ -1,3 +1,4 @@
from core.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
@ -93,16 +94,5 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
'required': False,
'options': ['JSON', 'XML'],
},
DefaultParameterName.JSON_SCHEMA: {
'label': {
'en_US': 'JSON Schema',
},
'type': 'text',
'help': {
'en_US': 'Set a response json schema will ensure LLM to adhere it.',
'zh_Hans': '设置返回的json schemallm将按照它返回',
},
'required': False,
},
}
}
}

View File

@ -95,7 +95,6 @@ class DefaultParameterName(Enum):
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
JSON_SCHEMA = "json_schema"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
@ -119,7 +118,6 @@ class ParameterType(Enum):
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
TEXT = "text"
class ModelPropertyKey(Enum):

View File

@ -84,8 +84,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)

View File

@ -31,14 +31,6 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: Base URL, 如https://api.moonshot.cn/v1
en_US: Base URL, e.g. https://api.moonshot.cn/v1
model_credential_schema:
model:
label:

View File

@ -2,7 +2,6 @@
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-2024-08-06
- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo

View File

@ -1,44 +0,0 @@
model: chatgpt-4o-latest
label:
zh_Hans: chatgpt-4o-latest
en_US: chatgpt-4o-latest
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -37,9 +37,6 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '2.50'
output: '10.00'

View File

@ -37,9 +37,6 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '0.15'
output: '0.60'

View File

@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
@ -545,18 +544,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not currect json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
if response_format == "json_object":
response_format = {"type": "json_object"}
else:
model_parameters["response_format"] = {"type": response_format}
response_format = {"type": "text"}
model_parameters["response_format"] = response_format
extra_model_kwargs = {}
@ -928,14 +922,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model.startswith('ft:'):
model = model.split(':')[1]
# Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
if model == "chatgpt-4o-latest":
model = "gpt-4o"
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
@ -955,7 +946,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://platform.openai.com/docs/advanced-usage/managing-tokens for "
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0

View File

@ -7,7 +7,6 @@ description:
supported_model_types:
- llm
- text-embedding
- speech2text
configurate_methods:
- customizable-model
model_credential_schema:
@ -62,22 +61,6 @@ model_credential_schema:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
show_on:
- variable: __model_type
value: llm
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
show_on:
- variable: __model_type
value: text-embedding
type: text-input
default: '4096'
placeholder:

View File

@ -1,63 +0,0 @@
from typing import IO, Optional
from urllib.parse import urljoin
import requests
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
"""
Model class for OpenAI Compatible Speech to text model.
"""
def _invoke(
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
headers = {}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get("endpoint_url")
if not endpoint_url.endswith("/"):
endpoint_url += "/"
endpoint_url = urljoin(endpoint_url, "audio/transcriptions")
payload = {"model": model}
files = [("file", file)]
response = requests.post(endpoint_url, headers=headers, data=payload, files=files)
if response.status_code != 200:
raise InvokeBadRequestError(response.text)
response_data = response.json()
return response_data["text"]
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, "rb") as audio_file:
self._invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@ -1,61 +0,0 @@
model: Llama3-Chinese_v2
label:
en_US: Llama3-Chinese_v2
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,61 +0,0 @@
model: Meta-Llama-3-70B-Instruct-GPTQ-Int4
label:
en_US: Meta-Llama-3-70B-Instruct-GPTQ-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 1024
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,61 +0,0 @@
model: Meta-Llama-3-8B-Instruct
label:
en_US: Meta-Llama-3-8B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,61 +0,0 @@
model: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
label:
en_US: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 410960
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,61 +0,0 @@
model: Meta-Llama-3.1-8B-Instruct
label:
en_US: Meta-Llama-3.1-8B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.1
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -55,8 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB
deprecated: true

View File

@ -55,8 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB
deprecated: true

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: chat
context_size: 2048
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
@ -55,7 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: completion
context_size: 32768
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
@ -55,7 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -8,12 +8,12 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 2048
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.7
default: 0.3
min: 0.0
max: 2.0
help:
@ -57,7 +57,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -1,61 +0,0 @@
model: Qwen2-72B-Instruct
label:
en_US: Qwen2-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -8,7 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: completion
context_size: 32768
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
@ -57,7 +57,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -1,15 +1,6 @@
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
- Meta-Llama-3.1-8B-Instruct
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
- Meta-Llama-3-8B-Instruct
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-72B-Instruct
- Qwen2-7B
- Qwen-14B-Chat-Int4
- Qwen1.5-110B-Chat-GPTQ-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- deepseek-v2-chat
- deepseek-v2-lite-chat
- Llama3-Chinese_v2
- chatglm3-6b
- Qwen-14B-Chat-Int4

View File

@ -1,61 +0,0 @@
model: chatglm3-6b
label:
en_US: chatglm3-6b
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,61 +0,0 @@
model: deepseek-v2-chat
label:
en_US: deepseek-v2-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,61 +0,0 @@
model: deepseek-v2-lite-chat
label:
en_US: deepseek-v2-lite-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 2048
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,4 +0,0 @@
model: BAAI/bge-large-en-v1.5
model_type: text-embedding
model_properties:
context_size: 32768

View File

@ -1,4 +0,0 @@
model: BAAI/bge-large-zh-v1.5
model_type: text-embedding
model_properties:
context_size: 32768

View File

@ -1,4 +0,0 @@
model: netease-youdao/bce-reranker-base_v1
model_type: rerank
model_properties:
context_size: 512

View File

@ -1,4 +0,0 @@
model: BAAI/bge-reranker-v2-m3
model_type: rerank
model_properties:
context_size: 8192

View File

@ -1,87 +0,0 @@
from typing import Optional
import httpx
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class SiliconflowRerankModel(RerankModel):
def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
if len(docs) == 0:
return RerankResult(model=model, docs=[])
base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1')
if base_url.endswith('/'):
base_url = base_url[:-1]
try:
response = httpx.post(
base_url + '/rerank',
json={
"model": model,
"query": query,
"documents": docs,
"top_n": top_n,
"return_documents": True
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results['results']:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}

View File

@ -6,7 +6,6 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class SiliconflowProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:

View File

@ -12,12 +12,10 @@ help:
en_US: Get your API Key from SiliconFlow
zh_Hans: 从 SiliconFlow 获取 API Key
url:
en_US: https://cloud.siliconflow.cn/account/ak
en_US: https://cloud.siliconflow.cn/keys
supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
configurate_methods:
- predefined-model
provider_credential_schema:

View File

@ -1,5 +0,0 @@
model: iic/SenseVoiceSmall
model_type: speech2text
model_properties:
file_upload_limit: 1
supported_file_extensions: mp3,wav

View File

@ -1,32 +0,0 @@
from typing import IO, Optional
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel
class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel):
"""
Model class for Siliconflow Speech to text model.
"""
def _invoke(
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, file)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
return super().validate_credentials(model, credentials)
@classmethod
def _add_custom_parameters(cls, credentials: dict) -> None:
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"

View File

@ -1,5 +0,0 @@
model: netease-youdao/bce-embedding-base_v1
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 1

View File

@ -1,5 +0,0 @@
model: BAAI/bge-m3
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 1

View File

@ -1,81 +0,0 @@
model: farui-plus
label:
en_US: farui-plus
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 12288
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 2000
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: seed
required: false
type: int
default: 1234
label:
zh_Hans: 随机种子
en_US: Random seed
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
- name: enable_search
type: boolean
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
use_template: response_format
pricing:
input: '0.02'
output: '0.02'
unit: '0.001'
currency: RMB

View File

@ -159,8 +159,6 @@ You should also complete the text started with ``` but not tell ``` directly.
"""
if model in ['qwen-turbo-chat', 'qwen-plus-chat']:
model = model.replace('-chat', '')
if model == 'farui-plus':
model = 'qwen-farui-plus'
if model in self.tokenizers:
tokenizer = self.tokenizers[model]

View File

@ -2,8 +2,3 @@ model: text-embedding-v1
model_type: text-embedding
model_properties:
context_size: 2048
max_chunks: 25
pricing:
input: "0.0007"
unit: "0.001"
currency: RMB

View File

@ -2,8 +2,3 @@ model: text-embedding-v2
model_type: text-embedding
model_properties:
context_size: 2048
max_chunks: 25
pricing:
input: "0.0007"
unit: "0.001"
currency: RMB

View File

@ -2,7 +2,6 @@ import time
from typing import Optional
import dashscope
import numpy as np
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
@ -22,11 +21,11 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
"""
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -38,44 +37,16 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
for i in _iter:
embeddings_batch, embedding_used_tokens = self.embed_documents(
credentials_kwargs=credentials_kwargs,
model=model,
texts=inputs[i : i + max_chunks],
)
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
# calc usage
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=used_tokens
embeddings, embedding_used_tokens = self.embed_documents(
credentials_kwargs=credentials_kwargs,
model=model,
texts=texts
)
return TextEmbeddingResult(
embeddings=batched_embeddings, usage=usage, model=model
embeddings=embeddings,
usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
@ -108,16 +79,12 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# call embedding model
self.embed_documents(
credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]
)
self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@staticmethod
def embed_documents(
credentials_kwargs: dict, model: str, texts: list[str]
) -> tuple[list[list[float]], int]:
def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]:
"""Call out to Tongyi's embedding endpoint.
Args:
@ -135,7 +102,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
api_key=credentials_kwargs["dashscope_api_key"],
model=model,
input=text,
text_type="document",
text_type="document"
)
data = response.output["embeddings"][0]
embeddings.append(data["embedding"])
@ -144,7 +111,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
"""
Calculate response usage
@ -158,7 +125,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens,
tokens=tokens
)
# transform usage
@ -169,7 +136,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at,
latency=time.perf_counter() - self.started_at
)
return usage

View File

@ -1 +1 @@
- solar-1-mini-chat
- soloar-1-mini-chat

View File

@ -35,10 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
get_model_config,
get_v2_req_params,
)
from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
logger = logging.getLogger(__name__)
@ -98,12 +95,37 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
-> LLMResult | Generator:
client = MaaSClient.from_credential(credentials)
req_params = get_v2_req_params(credentials, model_parameters, stop)
req_params = ModelConfigs.get(
credentials['base_model_name'], {}).get('req_params', {}).copy()
if credentials.get('context_size'):
req_params['max_prompt_tokens'] = credentials.get('context_size')
if credentials.get('max_tokens'):
req_params['max_new_tokens'] = credentials.get('max_tokens')
if model_parameters.get('max_tokens'):
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('top_k'):
req_params['top_k'] = model_parameters.get('top_k')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
extra_model_kwargs = {}
if tools:
extra_model_kwargs['tools'] = [
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
]
resp = MaaSClient.wrap_exception(
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
if not stream:
@ -175,8 +197,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
"""
used to define customizable model schema
"""
model_config = get_model_config(credentials)
max_tokens = ModelConfigs.get(
credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
if credentials.get('max_tokens'):
max_tokens = int(credentials.get('max_tokens'))
rules = [
ParameterRule(
name='temperature',
@ -210,10 +234,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='presence_penalty',
type=ParameterType.FLOAT,
use_template='presence_penalty',
label=I18nObject(
en_US='Presence Penalty',
zh_Hans= '存在惩罚',
),
label={
'en_US': 'Presence Penalty',
'zh_Hans': '存在惩罚',
},
min=-2.0,
max=2.0,
),
@ -221,10 +245,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='frequency_penalty',
type=ParameterType.FLOAT,
use_template='frequency_penalty',
label=I18nObject(
en_US= 'Frequency Penalty',
zh_Hans= '频率惩罚',
),
label={
'en_US': 'Frequency Penalty',
'zh_Hans': '频率惩罚',
},
min=-2.0,
max=2.0,
),
@ -233,7 +257,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=model_config.properties.max_tokens,
max=max_tokens,
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
@ -242,10 +266,17 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
),
]
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
model_properties = ModelConfigs.get(
credentials['base_model_name'], {}).get('model_properties', {}).copy()
if credentials.get('mode'):
model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
model_features = ModelConfigs.get(
credentials['base_model_name'], {}).get('features', [])
entity = AIModelEntity(
model=model,
label=I18nObject(
@ -255,7 +286,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_type=ModelType.LLM,
model_properties=model_properties,
parameter_rules=rules,
features=model_config.features,
features=model_features,
)
return entity

View File

@ -1,123 +1,181 @@
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature
class ModelProperties(BaseModel):
context_size: int
max_tokens: int
mode: LLMMode
class ModelConfig(BaseModel):
properties: ModelProperties
features: list[ModelFeature]
configs: dict[str, ModelConfig] = {
'Doubao-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Skylark2-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
features=[]
),
'Llama3-8B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[]
),
'Llama3-70B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-8k': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
features=[]
),
'GLM3-130B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'GLM3-130B-Fin': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Mistral-7B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
features=[]
)
ModelConfigs = {
'Doubao-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Skylark2-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4000,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [],
},
'Llama3-8B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Llama3-70B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-8k': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 16384,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 65536,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B-Fin': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Mistral-7B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 2048,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
}
}
def get_model_config(credentials: dict)->ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = configs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_tokens=int(credentials.get('max_tokens', 0)),
mode= LLMMode.value_of(credentials.get('mode', 'chat')),
),
features=[]
)
return model_configs
def get_v2_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None=None):
req_params = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
req_params['max_prompt_tokens'] = model_configs.properties.context_size
req_params['max_new_tokens'] = model_configs.properties.max_tokens
# model parameters
if model_parameters.get('max_tokens'):
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('top_k'):
req_params['top_k'] = model_parameters.get('top_k')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
return req_params

View File

@ -1,27 +1,9 @@
from pydantic import BaseModel
class ModelProperties(BaseModel):
context_size: int
max_chunks: int
class ModelConfig(BaseModel):
properties: ModelProperties
ModelConfigs = {
'Doubao-embedding': ModelConfig(
properties=ModelProperties(context_size=4096, max_chunks=1)
),
'Doubao-embedding': {
'req_params': {},
'model_properties': {
'context_size': 4096,
'max_chunks': 1,
}
},
}
def get_model_config(credentials: dict)->ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = ModelConfigs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_chunks=int(credentials.get('max_chunks', 0)),
)
)
return model_configs

View File

@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
@ -115,10 +115,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
"""
generate custom model entities from credentials
"""
model_config = get_model_config(credentials)
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
model_properties = ModelConfigs.get(
credentials['base_model_name'], {}).get('model_properties', {}).copy()
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
if credentials.get('max_chunks'):
model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
credentials.get('max_chunks', 4096))
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),

View File

@ -1,195 +0,0 @@
from datetime import datetime, timedelta
from threading import Lock
from requests import post
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)
@staticmethod
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLM from Baidu requires access token to invoke the API.
however, we have api_key and secret_key, and access token is valid for 30 days.
so we can cache the access token for 3 days. (avoid memory leak)
it may be more efficient to use a ticker to refresh access token, but it will cause
more complexity, so we just refresh access tokens when get_access_token is called.
"""
# loop up cache, remove expired access token
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)
if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release it to enhance performance
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token
class _CommonWenxin:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
credentials_kwargs = {
"api_key": credentials['api_key'],
"secret_key": credentials['secret_key']
}
return credentials_kwargs
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token

Some files were not shown because too many files have changed in this diff Show More