Mirror of https://github.com/langgenius/dify.git (synced 2026-01-19 11:45:05 +08:00)
Compare commits
139 Commits
| SHA1 |
|---|
| de57af46c0 |
| badf9baf9b |
| adcd83f6a8 |
| 81d4d8cea1 |
| 4da0b70694 |
| 7056009b6a |
| ddb960ddfb |
| 0ebd985672 |
| c13dc62065 |
| 705946cc40 |
| aafa4a3c8b |
| af68084895 |
| 9633c5dab6 |
| aa11141660 |
| 8bb5b943d7 |
| 22776f24ab |
| 216442ddc1 |
| dd3ac7a2c9 |
| 11447324ff |
| f8210b353e |
| 2b66c1358b |
| 102d86d4b6 |
| 227f49a0cc |
| a17f169e01 |
| 72ea3d6b98 |
| 17cacf258e |
| f7aacefcd6 |
| ace7ffab5f |
| eec63b112f |
| caf7bc8569 |
| fd437ff4c5 |
| fb218f8b10 |
| 4693080ce0 |
| 60ddcdf960 |
| 303bafb3ac |
| 7a0d0d9b96 |
| 84a9d2d072 |
| 1b5adf40da |
| 59a32aaae6 |
| 18106a4fc6 |
| fc2297a2ca |
| 5b7b765090 |
| 90769ac709 |
| ac9f1e9de5 |
| 5bf31e7a86 |
| dd17506078 |
| 5d1424f67c |
| 2346b0ab99 |
| 88dec6ef2b |
| e71f494839 |
| 22bb0414a1 |
| 6477bb8d77 |
| 70ddc0ce43 |
| 9986e4c6d0 |
| e2710161f6 |
| 5f11fe521d |
| d018b32d0b |
| e54b7cda3d |
| fc63841169 |
| b674c598f9 |
| 710230a294 |
| 169f7440ac |
| 57ec12eb6b |
| 2c26f77a25 |
| 95dc90e6b2 |
| 400392230b |
| eca66f9577 |
| 121bb99cc2 |
| cac1ef7ade |
| d74d79b3d8 |
| c6b28bc193 |
| 5d05574518 |
| bf478aeba2 |
| c9dfe1ad92 |
| 926609eb59 |
| e32116b9a3 |
| e11d5ac708 |
| f6c3d4cadc |
| 3e9d271b52 |
| ecc8beef3f |
| b9afb7bcec |
| b4041759f7 |
| c3473b5b4f |
| 1b9bf9c62d |
| ed96a6b6c0 |
| 4989d0c904 |
| 9a5bdae07f |
| 67016feb96 |
| 22bdfb7e56 |
| ceb2c4f3ef |
| d5a93a6400 |
| 01a2513812 |
| 8e7a752b2a |
| 999d3f1539 |
| a7ee51e5d8 |
| 0e965b6529 |
| a9db06f5e7 |
| 6827c4038b |
| e8a6e90a61 |
| ff956cb546 |
| 7d7e0f9800 |
| 3ae05a672d |
| d700abff0a |
| 5267f34e76 |
| d6e8290a1c |
| 36f66d40e5 |
| 5f12616cb9 |
| bc43efba75 |
| ef5f476cd6 |
| 98bf7710e4 |
| 7263af13ed |
| d992a809f5 |
| 04f8d39860 |
| b7bf14ab72 |
| e8abbe0623 |
| b14d59e977 |
| 5f12c17355 |
| d170d78530 |
| 4d9160ca9f |
| 8f670f31b8 |
| 5838345f48 |
| 3f1c84f65a |
| 83b2b8fe60 |
| ac24300274 |
| 2e657b7b12 |
| c063617553 |
| 38a4f0234d |
| 740a723072 |
| 495cf58014 |
| 8e98759359 |
| 4ae0bb83f1 |
| 5459d812e7 |
| 831c222541 |
| faad247d85 |
| 1e829ceaf3 |
| 79fe175440 |
| 9b32bfb3db |
| 37fea072bc |
| 31a603e905 |
@@ -1,5 +1,9 @@
 
+<p align="center">
+  📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
+</p>
+
 <p align="center">
   <a href="https://cloud.dify.ai">Dify Cloud</a> ·
   <a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
@@ -168,7 +172,7 @@ Star Dify on GitHub and be instantly notified of new releases.
 > Before installing Dify, make sure your machine meets the following minimum system requirements:
 >
 >- CPU >= 2 Core
->- RAM >= 4GB
+>- RAM >= 4 GiB
 
 </br>
@@ -154,7 +154,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface combines AI
 We offer [Dify Cloud](https://dify.ai), which anyone can try with zero setup. It provides all the features of the self-deployed edition and includes 200 free GPT-4 calls in the sandbox plan.
 
 - **Self-hosting Dify Community Edition</br>**
-  Quickly get Dify running in your environment with this [starter guide](#quick-start).
+  Quickly get Dify running in your environment with this [starter guide](#快速启动).
   Use our [documentation](https://docs.dify.ai) for further reference and more in-depth instructions.
 
 - **Dify for enterprises / organizations</br>**
@@ -174,7 +174,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface combines AI
 Before installing Dify, make sure your machine meets the following minimum system requirements:
 
 - CPU >= 2 Core
-- RAM >= 4GB
+- RAM >= 4 GiB
 
 ### 快速启动
@@ -31,8 +31,17 @@ REDIS_HOST=localhost
 REDIS_PORT=6379
 REDIS_USERNAME=
 REDIS_PASSWORD=difyai123456
 REDIS_USE_SSL=false
+REDIS_DB=0
+
+# redis Sentinel configuration.
+REDIS_USE_SENTINEL=false
+REDIS_SENTINELS=
+REDIS_SENTINEL_SERVICE_NAME=
+REDIS_SENTINEL_USERNAME=
+REDIS_SENTINEL_PASSWORD=
+REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
 
 # PostgreSQL database configuration
 DB_USERNAME=postgres
 DB_PASSWORD=difyai123456
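For context, here is a minimal sketch of how settings like these map onto a redis-py Sentinel client; the host names, service name, and credentials below are placeholders, not values this diff ships with:

```python
# Minimal sketch: REDIS_SENTINEL_* settings expressed with redis-py.
from redis.sentinel import Sentinel

sentinel = Sentinel(
    [("sentinel-1", 26379), ("sentinel-2", 26379)],  # parsed from REDIS_SENTINELS
    sentinel_kwargs={"username": None, "password": None},  # REDIS_SENTINEL_USERNAME / _PASSWORD
    socket_timeout=0.1,  # REDIS_SENTINEL_SOCKET_TIMEOUT
)
# Ask Sentinel for the current master of the named service
# (REDIS_SENTINEL_SERVICE_NAME), then use it like a regular Redis client.
master = sentinel.master_for("mymaster", db=0, password="difyai123456")
master.ping()
```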
@@ -111,7 +120,7 @@ SUPABASE_URL=your-server-url
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb, upstash
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -220,6 +229,10 @@ BAIDU_VECTOR_DB_DATABASE=dify
 BAIDU_VECTOR_DB_SHARD=1
 BAIDU_VECTOR_DB_REPLICAS=3
 
+# Upstash configuration
+UPSTASH_VECTOR_URL=your-server-url
+UPSTASH_VECTOR_TOKEN=your-access-token
+
 # ViKingDB configuration
 VIKINGDB_ACCESS_KEY=your-ak
 VIKINGDB_SECRET_KEY=your-sk
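To sanity-check these two values, a short sketch with the `upstash-vector` Python SDK; the SDK usage is an assumption on my part, not part of this diff, and the URL/token are the placeholders from above:

```python
# Minimal connectivity check against Upstash Vector.
from upstash_vector import Index

index = Index(url="your-server-url", token="your-access-token")
index.upsert(vectors=[("doc-1", [0.1, 0.2, 0.3], {"source": "demo"})])
print(index.query(vector=[0.1, 0.2, 0.3], top_k=1))
```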
@@ -239,6 +252,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 # Model Configuration
 MULTIMODAL_SEND_IMAGE_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
+CODE_GENERATION_MAX_TOKENS=1024
 
 # Mail configuration, support: resend, smtp
 MAIL_TYPE=
@@ -304,6 +318,10 @@ RESPECT_XFORWARD_HEADERS_ENABLED=false
 
 # Log file path
 LOG_FILE=
+# Log file max size, the unit is MB
+LOG_FILE_MAX_SIZE=20
+# Log file max backup count
+LOG_FILE_BACKUP_COUNT=5
 
 # Indexing configuration
 INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
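For reference, these two settings feed Python's standard rotation handler roughly like this; a sketch, with a hypothetical log path, where the MB-to-bytes conversion mirrors the `LOG_FILE_MAX_SIZE` description added to `LoggingConfig` further down:

```python
import logging
from logging.handlers import RotatingFileHandler

handler = RotatingFileHandler(
    filename="/app/logs/server.log",  # LOG_FILE (hypothetical path)
    maxBytes=20 * 1024 * 1024,        # LOG_FILE_MAX_SIZE=20, megabytes -> bytes
    backupCount=5,                    # LOG_FILE_BACKUP_COUNT=5
)
logging.basicConfig(level=logging.INFO, handlers=[handler])
logging.getLogger(__name__).info("rotates after ~20 MB, keeps 5 backups")
```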
@@ -55,7 +55,9 @@ RUN apt-get update \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.1-1 \
+    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
     # install a chinese font to support the use of tools like matplotlib
     && apt-get install -y fonts-noto-cjk \
     && apt-get autoremove -y \
    && rm -rf /var/lib/apt/lists/*
api/app.py (16 lines changed)
@@ -1,5 +1,7 @@
 import os
 
+from configs import dify_config
+
 if os.environ.get("DEBUG", "false").lower() != "true":
     from gevent import monkey
@@ -36,17 +38,11 @@ if hasattr(time, "tzset"):
     time.tzset()
 
 
-# -------------
-# Configuration
-# -------------
-config_type = os.getenv("EDITION", default="SELF_HOSTED")  # ce edition first
-
-
 # create app
 app = create_app()
 celery = app.extensions["celery"]
 
-if app.config.get("TESTING"):
+if dify_config.TESTING:
     print("App is running in TESTING mode")
 
@@ -54,15 +50,15 @@ if app.config.get("TESTING"):
 def after_request(response):
     """Add Version headers to the response."""
     response.set_cookie("remember_token", "", expires=0)
-    response.headers.add("X-Version", app.config["CURRENT_VERSION"])
-    response.headers.add("X-Env", app.config["DEPLOY_ENV"])
+    response.headers.add("X-Version", dify_config.CURRENT_VERSION)
+    response.headers.add("X-Env", dify_config.DEPLOY_ENV)
     return response
 
 
 @app.route("/health")
 def health():
     return Response(
-        json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
+        json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
         status=200,
         content_type="application/json",
     )
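A quick way to exercise the `/health` route once the API is up; port 5001 is an assumption (Dify's usual API port), so adjust for your deployment:

```python
import json
import urllib.request

with urllib.request.urlopen("http://localhost:5001/health") as resp:
    payload = json.load(resp)
print(payload)  # e.g. {"pid": 1, "status": "ok", "version": "0.10.2"}
```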
@@ -10,9 +10,6 @@ if os.environ.get("DEBUG", "false").lower() != "true":
     grpc.experimental.gevent.init_gevent()
 
 import json
-import logging
-import sys
-from logging.handlers import RotatingFileHandler
 
 from flask import Flask, Response, request
 from flask_cors import CORS
@@ -27,6 +24,7 @@ from extensions import (
     ext_compress,
     ext_database,
     ext_hosting_provider,
+    ext_logging,
     ext_login,
     ext_mail,
     ext_migrate,
@@ -70,43 +68,7 @@ def create_flask_app_with_configs() -> Flask:
 
 def create_app() -> Flask:
     app = create_flask_app_with_configs()
-
-    app.secret_key = app.config["SECRET_KEY"]
-
-    log_handlers = None
-    log_file = app.config.get("LOG_FILE")
-    if log_file:
-        log_dir = os.path.dirname(log_file)
-        os.makedirs(log_dir, exist_ok=True)
-        log_handlers = [
-            RotatingFileHandler(
-                filename=log_file,
-                maxBytes=1024 * 1024 * 1024,
-                backupCount=5,
-            ),
-            logging.StreamHandler(sys.stdout),
-        ]
-
-    logging.basicConfig(
-        level=app.config.get("LOG_LEVEL"),
-        format=app.config.get("LOG_FORMAT"),
-        datefmt=app.config.get("LOG_DATEFORMAT"),
-        handlers=log_handlers,
-        force=True,
-    )
-    log_tz = app.config.get("LOG_TZ")
-    if log_tz:
-        from datetime import datetime
-
-        import pytz
-
-        timezone = pytz.timezone(log_tz)
-
-        def time_converter(seconds):
-            return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
-
-        for handler in logging.root.handlers:
-            handler.formatter.converter = time_converter
+    app.secret_key = dify_config.SECRET_KEY
     initialize_extensions(app)
     register_blueprints(app)
     register_commands(app)
@@ -117,6 +79,7 @@ def create_app() -> Flask:
 def initialize_extensions(app):
     # Since the application instance is now created, pass it to each Flask
     # extension instance to bind it to the Flask application instance (app)
+    ext_logging.init_app(app)
     ext_compress.init_app(app)
     ext_code_based_extension.init()
     ext_database.init_app(app)
@@ -187,7 +150,7 @@ def register_blueprints(app):
 
     CORS(
         web_bp,
-        resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
+        resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,
         allow_headers=["Content-Type", "Authorization", "X-App-Code"],
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@@ -198,7 +161,7 @@ def register_blueprints(app):
 
     CORS(
         console_app_bp,
-        resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
+        resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
         supports_credentials=True,
         allow_headers=["Content-Type", "Authorization"],
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@@ -277,6 +277,7 @@ def migrate_knowledge_vector_database():
         VectorType.TENCENT,
         VectorType.BAIDU,
         VectorType.VIKINGDB,
+        VectorType.UPSTASH,
     }
     page = 1
     while True:
@@ -32,6 +32,21 @@ class SecurityConfig(BaseSettings):
         default=5,
     )
 
+    LOGIN_DISABLED: bool = Field(
+        description="Whether to disable login checks",
+        default=False,
+    )
+
+    ADMIN_API_KEY_ENABLE: bool = Field(
+        description="Whether to enable admin api key for authentication",
+        default=False,
+    )
+
+    ADMIN_API_KEY: Optional[str] = Field(
+        description="admin api key for authentication",
+        default=None,
+    )
+
 
 class AppExecutionConfig(BaseSettings):
     """
@@ -304,6 +319,16 @@ class LoggingConfig(BaseSettings):
         default=None,
     )
 
+    LOG_FILE_MAX_SIZE: PositiveInt = Field(
+        description="Maximum file size for file rotation retention, the unit is megabytes (MB)",
+        default=20,
+    )
+
+    LOG_FILE_BACKUP_COUNT: PositiveInt = Field(
+        description="Maximum file backup count file rotation retention",
+        default=5,
+    )
+
     LOG_FORMAT: str = Field(
         description="Format string for log messages",
         default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
@@ -546,6 +571,11 @@ class DataSetConfig(BaseSettings):
         default=False,
     )
 
+    TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
+        description="number of tidb serverless cluster",
+        default=500,
+    )
+
 
 class WorkspaceConfig(BaseSettings):
     """
@@ -27,7 +27,9 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
 from configs.middleware.vdb.qdrant_config import QdrantConfig
 from configs.middleware.vdb.relyt_config import RelytConfig
 from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
+from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
 from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
+from configs.middleware.vdb.upstash_config import UpstashConfig
 from configs.middleware.vdb.vikingdb_config import VikingDBConfig
 from configs.middleware.vdb.weaviate_config import WeaviateConfig
@@ -53,6 +55,11 @@ class VectorStoreConfig(BaseSettings):
         default=None,
     )
 
+    VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
+        description="Enable whitelist for vector store.",
+        default=False,
+    )
+
 
 class KeywordStoreConfig(BaseSettings):
     KEYWORD_STORE: str = Field(
@@ -246,5 +253,7 @@ class MiddlewareConfig(
     ElasticsearchConfig,
     InternalTestConfig,
     VikingDBConfig,
+    UpstashConfig,
+    TidbOnQdrantConfig,
 ):
     pass
api/configs/middleware/vdb/tidb_on_qdrant_config.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+from typing import Optional
+
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class TidbOnQdrantConfig(BaseSettings):
+    """
+    Tidb on Qdrant configs
+    """
+
+    TIDB_ON_QDRANT_URL: Optional[str] = Field(
+        description="Tidb on Qdrant url",
+        default=None,
+    )
+
+    TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
+        description="Tidb on Qdrant api key",
+        default=None,
+    )
+
+    TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
+        description="Tidb on Qdrant client timeout in seconds",
+        default=20,
+    )
+
+    TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
+        description="whether enable grpc support for Tidb on Qdrant connection",
+        default=False,
+    )
+
+    TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
+        description="Tidb on Qdrant grpc port",
+        default=6334,
+    )
+
+    TIDB_PUBLIC_KEY: Optional[str] = Field(
+        description="Tidb account public key",
+        default=None,
+    )
+
+    TIDB_PRIVATE_KEY: Optional[str] = Field(
+        description="Tidb account private key",
+        default=None,
+    )
+
+    TIDB_API_URL: Optional[str] = Field(
+        description="Tidb API url",
+        default=None,
+    )
+
+    TIDB_IAM_API_URL: Optional[str] = Field(
+        description="Tidb IAM API url",
+        default=None,
+    )
+
+    TIDB_REGION: Optional[str] = Field(
+        description="Tidb serverless region",
+        default="regions/aws-us-east-1",
+    )
+
+    TIDB_PROJECT_ID: Optional[str] = Field(
+        description="Tidb project id",
+        default=None,
+    )
api/configs/middleware/vdb/upstash_config.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from typing import Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class UpstashConfig(BaseSettings):
+    """
+    Configuration settings for Upstash vector database
+    """
+
+    UPSTASH_VECTOR_URL: Optional[str] = Field(
+        description="URL of the upstash server (e.g., 'https://vector.upstash.io')",
+        default=None,
+    )
+
+    UPSTASH_VECTOR_TOKEN: Optional[str] = Field(
+        description="Token for authenticating with the upstash server",
+        default=None,
+    )
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.10.0",
+        default="0.10.2",
     )
 
     COMMIT_SHA: str = Field(
@@ -15,7 +15,9 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
 
 if dify_config.ETL_TYPE == "Unstructured":
     DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
-    DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
+    DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
+    if dify_config.UNSTRUCTURED_API_URL:
+        DOCUMENT_EXTENSIONS.append("ppt")
     DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
 else:
     DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
@@ -1,10 +1,10 @@
-import os
 from functools import wraps
 
 from flask import request
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import NotFound, Unauthorized
 
+from configs import dify_config
 from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.wraps import only_edition_cloud
@@ -15,7 +15,7 @@ from models.model import App, InstalledApp, RecommendedApp
 def admin_required(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if not os.getenv("ADMIN_API_KEY"):
+        if not dify_config.ADMIN_API_KEY:
             raise Unauthorized("API key is invalid.")
 
         auth_header = request.headers.get("Authorization")
@@ -31,7 +31,7 @@ def admin_required(view):
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
 
-        if os.getenv("ADMIN_API_KEY") != auth_token:
+        if dify_config.ADMIN_API_KEY != auth_token:
             raise Unauthorized("API key is invalid.")
 
         return view(*args, **kwargs)
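Put together, the decorator expects a standard Bearer header carrying the `ADMIN_API_KEY` value. A hedged client-side sketch; the endpoint path and port are illustrative, not taken from this diff:

```python
import requests

resp = requests.post(
    "http://localhost:5001/console/api/admin/insert-explore-apps",  # illustrative admin route
    headers={"Authorization": "Bearer <value of ADMIN_API_KEY>"},
    json={},
)
print(resp.status_code)  # 401 with "API key is invalid." when the token mismatches
```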
@@ -52,4 +52,39 @@ class RuleGenerateApi(Resource):
         return rules
 
 
+class RuleCodeGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
+        parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
+        args = parser.parse_args()
+
+        account = current_user
+        CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
+        try:
+            code_result = LLMGenerator.generate_code(
+                tenant_id=account.current_tenant_id,
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+                code_language=args["code_language"],
+                max_tokens=CODE_GENERATION_MAX_TOKENS,
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+
+        return code_result
+
+
 api.add_resource(RuleGenerateApi, "/rule-generate")
+api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
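A sketch of what a request to the new endpoint might look like; the console URL prefix, auth header, and model names are assumptions about a concrete deployment, not part of the diff:

```python
import requests

resp = requests.post(
    "http://localhost:5001/console/api/rule-code-generate",
    headers={"Authorization": "Bearer <console-session-token>"},
    json={
        "instruction": "Sum a list of integers and return the total",
        "model_config": {"provider": "openai", "name": "gpt-4o-mini"},
        "no_variable": False,
        "code_language": "python",
    },
)
print(resp.json())  # {"code": "...", "language": "python", "error": ""}
```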
@@ -105,6 +105,8 @@ class ChatMessageListApi(Resource):
         if rest_count > 0:
             has_more = True
 
+        history_messages = list(reversed(history_messages))
+
         return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
@@ -1,11 +1,10 @@
 from typing import cast
 
 import flask_login
-from flask import redirect, request
+from flask import request
 from flask_restful import Resource, reqparse
 
 import services
-from configs import dify_config
 from constants.languages import languages
 from controllers.console import api
 from controllers.console.auth.error import (
@@ -196,10 +195,7 @@ class EmailCodeLoginApi(Resource):
                     email=user_email, name=user_email, interface_language=languages[0]
                 )
             except WorkSpaceNotAllowedCreateError:
-                return redirect(
-                    f"{dify_config.CONSOLE_WEB_URL}/signin"
-                    "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
-                )
+                return NotAllowedCreateWorkspace()
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
         AccountService.reset_login_error_rate_limit(args["email"])
         return {"result": "success", "data": token_pair.model_dump()}
@@ -94,17 +94,15 @@ class OAuthCallback(Resource):
             account = _generate_account(provider, user_info)
         except AccountNotFoundError:
             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
-        except WorkSpaceNotFoundError:
-            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
-        except WorkSpaceNotAllowedCreateError:
+        except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
             return redirect(
                 f"{dify_config.CONSOLE_WEB_URL}/signin"
                 "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
             )
 
         # Check account status
-        if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
-            return {"error": "Account is banned or closed."}, 403
+        if account.status == AccountStatus.BANNED.value:
+            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
@@ -102,6 +102,13 @@ class DatasetListApi(Resource):
             help="type is required. Name must be between 1 to 40 characters.",
             type=_validate_name,
         )
+        parser.add_argument(
+            "description",
+            type=str,
+            nullable=True,
+            required=False,
+            default="",
+        )
         parser.add_argument(
             "indexing_technique",
             type=str,
@@ -140,6 +147,7 @@ class DatasetListApi(Resource):
         dataset = DatasetService.create_empty_dataset(
             tenant_id=current_user.current_tenant_id,
             name=args["name"],
+            description=args["description"],
             indexing_technique=args["indexing_technique"],
             account=current_user,
             permission=DatasetPermissionEnum.ONLY_ME,
@@ -619,6 +627,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.PGVECTO_RS
                 | VectorType.BAIDU
                 | VectorType.VIKINGDB
+                | VectorType.UPSTASH
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -630,6 +639,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
                 | VectorType.PGVECTOR
+                | VectorType.TIDB_ON_QDRANT
             ):
                 return {
                     "retrieval_method": [
@@ -657,6 +667,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.PGVECTO_RS
                 | VectorType.BAIDU
                 | VectorType.VIKINGDB
+                | VectorType.UPSTASH
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -30,13 +30,12 @@ class FileApi(Resource):
     @account_initialization_required
     @marshal_with(upload_config_fields)
     def get(self):
-        file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT
-        batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
-        image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
         return {
-            "file_size_limit": file_size_limit,
-            "batch_count_limit": batch_count_limit,
-            "image_file_size_limit": image_file_size_limit,
+            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
+            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
         }, 200
 
     @setup_required
@@ -41,7 +41,7 @@ class AlreadyActivateError(BaseHTTPException):
 
 
 class NotAllowedCreateWorkspace(BaseHTTPException):
-    error_code = "unauthorized"
+    error_code = "not_allowed_create_workspace"
     description = "Workspace not found, please contact system admin to invite you to join in a workspace."
     code = 400
@@ -21,7 +21,12 @@ class AppParameterApi(InstalledAppResource):
         "options": fields.List(fields.String),
     }
 
-    system_parameters_fields = {"image_file_size_limit": fields.String}
+    system_parameters_fields = {
+        "image_file_size_limit": fields.Integer,
+        "video_file_size_limit": fields.Integer,
+        "audio_file_size_limit": fields.Integer,
+        "file_size_limit": fields.Integer,
+    }
 
     parameters_fields = {
         "opening_statement": fields.String,
@@ -82,7 +87,12 @@ class AppParameterApi(InstalledAppResource):
                     }
                 },
             ),
-            "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
+            "system_parameters": {
+                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            },
         }
@@ -1,5 +1,5 @@
 from flask import Response, request
-from flask_restful import Resource
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import NotFound
 
 import services
@@ -41,24 +41,39 @@ class FilePreviewApi(Resource):
     def get(self, file_id):
         file_id = str(file_id)
 
-        timestamp = request.args.get("timestamp")
-        nonce = request.args.get("nonce")
-        sign = request.args.get("sign")
+        parser = reqparse.RequestParser()
+        parser.add_argument("timestamp", type=str, required=True, location="args")
+        parser.add_argument("nonce", type=str, required=True, location="args")
+        parser.add_argument("sign", type=str, required=True, location="args")
+        parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
 
-        if not timestamp or not nonce or not sign:
+        args = parser.parse_args()
+
+        if not args["timestamp"] or not args["nonce"] or not args["sign"]:
             return {"content": "Invalid request."}, 400
 
         try:
-            generator, mimetype = FileService.get_signed_file_preview(
+            generator, upload_file = FileService.get_file_generator_by_file_id(
                 file_id=file_id,
-                timestamp=timestamp,
-                nonce=nonce,
-                sign=sign,
+                timestamp=args["timestamp"],
+                nonce=args["nonce"],
+                sign=args["sign"],
             )
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
-        return Response(generator, mimetype=mimetype)
+        response = Response(
+            generator,
+            mimetype=upload_file.mime_type,
+            direct_passthrough=True,
+            headers={},
+        )
+        if upload_file.size > 0:
+            response.headers["Content-Length"] = str(upload_file.size)
+        if args["as_attachment"]:
+            response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}"
+
+        return response
 
 
 class WorkspaceWebappLogoApi(Resource):
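Client-side, the reworked route now honors `as_attachment`. A sketch of the request shape; the route path and the signed query values are placeholders that Dify generates elsewhere:

```python
import requests

resp = requests.get(
    "http://localhost:5001/files/<file_id>/file-preview",  # placeholder route
    params={
        "timestamp": "<unix-timestamp>",
        "nonce": "<nonce>",
        "sign": "<signature>",
        "as_attachment": "true",  # triggers the Content-Disposition header above
    },
)
print(resp.headers.get("Content-Length"), resp.headers.get("Content-Disposition"))
```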
@@ -42,10 +42,10 @@ class ToolFilePreviewApi(Resource):
             stream,
             mimetype=tool_file.mimetype,
             direct_passthrough=True,
-            headers={
-                "Content-Length": str(tool_file.size),
-            },
+            headers={},
         )
+        if tool_file.size > 0:
+            response.headers["Content-Length"] = str(tool_file.size)
         if args["as_attachment"]:
             response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
@@ -21,7 +21,7 @@ class EnterpriseWorkspace(Resource):
         if account is None:
             return {"message": "owner account not found."}, 404
 
-        tenant = TenantService.create_tenant(args["name"])
+        tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
         TenantService.create_tenant_member(tenant, account, role="owner")
 
         tenant_was_created.send(tenant)
@@ -21,7 +21,12 @@ class AppParameterApi(Resource):
         "options": fields.List(fields.String),
     }
 
-    system_parameters_fields = {"image_file_size_limit": fields.String}
+    system_parameters_fields = {
+        "image_file_size_limit": fields.Integer,
+        "video_file_size_limit": fields.Integer,
+        "audio_file_size_limit": fields.Integer,
+        "file_size_limit": fields.Integer,
+    }
 
     parameters_fields = {
         "opening_statement": fields.String,
@@ -81,7 +86,12 @@ class AppParameterApi(Resource):
                     }
                 },
             ),
-            "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
+            "system_parameters": {
+                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            },
         }
@@ -48,7 +48,7 @@ class MessageListApi(Resource):
         "tool_input": fields.String,
         "created_at": TimestampField,
         "observation": fields.String,
-        "message_files": fields.List(fields.String),
+        "message_files": fields.List(fields.Nested(message_file_fields)),
     }
 
     message_fields = {
@@ -66,6 +66,13 @@ class DatasetListApi(DatasetApiResource):
             help="type is required. Name must be between 1 to 40 characters.",
             type=_validate_name,
         )
+        parser.add_argument(
+            "description",
+            type=str,
+            nullable=True,
+            required=False,
+            default="",
+        )
         parser.add_argument(
             "indexing_technique",
             type=str,
@@ -108,6 +115,7 @@ class DatasetListApi(DatasetApiResource):
         dataset = DatasetService.create_empty_dataset(
             tenant_id=tenant_id,
             name=args["name"],
+            description=args["description"],
             indexing_technique=args["indexing_technique"],
             account=current_user,
             permission=args["permission"],
@@ -21,7 +21,12 @@ class AppParameterApi(WebApiResource):
         "options": fields.List(fields.String),
     }
 
-    system_parameters_fields = {"image_file_size_limit": fields.String}
+    system_parameters_fields = {
+        "image_file_size_limit": fields.Integer,
+        "video_file_size_limit": fields.Integer,
+        "audio_file_size_limit": fields.Integer,
+        "file_size_limit": fields.Integer,
+    }
 
     parameters_fields = {
         "opening_statement": fields.String,
@@ -80,7 +85,12 @@ class AppParameterApi(WebApiResource):
                     }
                 },
             ),
-            "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
+            "system_parameters": {
+                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            },
         }
@@ -46,7 +46,7 @@ class RemoteFileInfoApi(WebApiResource):
             response = ssrf_proxy.head(decoded_url)
             return {
                 "file_type": response.headers.get("Content-Type", "application/octet-stream"),
-                "file_length": int(response.headers.get("Content-Length", 0)),
+                "file_length": int(response.headers.get("Content-Length", -1)),
             }
         except Exception as e:
             return {"error": str(e)}, 400
@@ -165,6 +165,12 @@ class BaseAgentRunner(AppRunner):
                 continue
 
             parameter_type = parameter.type.as_normal_type()
+            if parameter.type in {
+                ToolParameter.ToolParameterType.SYSTEM_FILES,
+                ToolParameter.ToolParameterType.FILE,
+                ToolParameter.ToolParameterType.FILES,
+            }:
+                continue
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
@@ -250,6 +256,12 @@ class BaseAgentRunner(AppRunner):
                 continue
 
             parameter_type = parameter.type.as_normal_type()
+            if parameter.type in {
+                ToolParameter.ToolParameterType.SYSTEM_FILES,
+                ToolParameter.ToolParameterType.FILE,
+                ToolParameter.ToolParameterType.FILES,
+            }:
+                continue
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
@@ -53,11 +53,11 @@ class BasicVariablesConfigManager:
                     VariableEntity(
                         type=variable_type,
                         variable=variable.get("variable"),
-                        description=variable.get("description", ""),
+                        description=variable.get("description") or "",
                        label=variable.get("label"),
                         required=variable.get("required", False),
                         max_length=variable.get("max_length"),
-                        options=variable.get("options", []),
+                        options=variable.get("options") or [],
                     )
                 )
@@ -2,7 +2,7 @@ from collections.abc import Sequence
 from enum import Enum
 from typing import Any, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 
 from core.file import FileExtraConfig, FileTransferMethod, FileType
 from core.model_runtime.entities.message_entities import PromptMessageRole
@@ -114,6 +114,16 @@ class VariableEntity(BaseModel):
     allowed_file_extensions: Sequence[str] = Field(default_factory=list)
     allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
 
+    @field_validator("description", mode="before")
+    @classmethod
+    def convert_none_description(cls, v: Any) -> str:
+        return v or ""
+
+    @field_validator("options", mode="before")
+    @classmethod
+    def convert_none_options(cls, v: Any) -> Sequence[str]:
+        return v or []
+
 
 class ExternalDataVariableEntity(BaseModel):
     """
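A standalone sketch of what these `mode="before"` validators buy: `None` inputs are normalized before type validation runs, instead of raising. The model here is a stripped-down stand-in for `VariableEntity`:

```python
from collections.abc import Sequence
from typing import Any

from pydantic import BaseModel, field_validator


class VariableEntityDemo(BaseModel):  # stand-in for VariableEntity
    description: str = ""
    options: Sequence[str] = ()

    @field_validator("description", mode="before")
    @classmethod
    def convert_none_description(cls, v: Any) -> str:
        return v or ""

    @field_validator("options", mode="before")
    @classmethod
    def convert_none_options(cls, v: Any) -> Sequence[str]:
        return v or []


# Without the validators, description=None would raise a ValidationError;
# with them, it is coerced to "" (and options=None to []).
print(VariableEntityDemo(description=None, options=None))
```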
@@ -17,10 +17,13 @@ class FileUploadConfigManager:
         file_upload_dict = config.get("file_upload")
         if file_upload_dict:
             if file_upload_dict.get("enabled"):
+                transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
+                    "allowed_upload_methods", []
+                )
                 data = {
                     "image_config": {
                         "number_limits": file_upload_dict["number_limits"],
-                        "transfer_methods": file_upload_dict["allowed_file_upload_methods"],
+                        "transfer_methods": transform_methods,
                     }
                 }
@@ -27,6 +27,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
 from models import Account
+from models.enums import CreatedByRole
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
@@ -240,7 +241,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                 belongs_to="user",
                 url=file.remote_url,
                 upload_file_id=file.related_id,
-                created_by_role=("account" if account_id else "end_user"),
+                created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
                 created_by=account_id or end_user_id or "",
             )
             db.session.add(message_file)
@@ -53,7 +53,7 @@ class BasedGenerateTaskPipeline:
         self._output_moderation_handler = self._init_output_moderation()
         self._stream = stream
 
-    def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception:
+    def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
         """
         Handle error event.
         :param event: event
@@ -100,7 +100,7 @@ class BasedGenerateTaskPipeline:
 
         return message
 
-    def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse:
+    def _error_to_stream_response(self, e: Exception):
         """
         Error to stream response.
         :param e: exception
@@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence
 from datetime import datetime, timezone
 from typing import Any, Optional, Union, cast
 
+from sqlalchemy.orm import Session
+
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
     QueueIterationCompletedEvent,
@@ -232,30 +234,30 @@ class WorkflowCycleManage:
         self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
     ) -> WorkflowNodeExecution:
-        # init workflow node execution
-        workflow_node_execution = WorkflowNodeExecution()
-        workflow_node_execution.tenant_id = workflow_run.tenant_id
-        workflow_node_execution.app_id = workflow_run.app_id
-        workflow_node_execution.workflow_id = workflow_run.workflow_id
-        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
-        workflow_node_execution.workflow_run_id = workflow_run.id
-        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
-        workflow_node_execution.index = event.node_run_index
-        workflow_node_execution.node_execution_id = event.node_execution_id
-        workflow_node_execution.node_id = event.node_id
-        workflow_node_execution.node_type = event.node_type.value
-        workflow_node_execution.title = event.node_data.title
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
-        workflow_node_execution.created_by_role = workflow_run.created_by_role
-        workflow_node_execution.created_by = workflow_run.created_by
-        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
-
-        db.session.add(workflow_node_execution)
-        db.session.commit()
-        db.session.refresh(workflow_node_execution)
-        db.session.close()
+        with Session(db.engine, expire_on_commit=False) as session:
+            workflow_node_execution = WorkflowNodeExecution()
+            workflow_node_execution.tenant_id = workflow_run.tenant_id
+            workflow_node_execution.app_id = workflow_run.app_id
+            workflow_node_execution.workflow_id = workflow_run.workflow_id
+            workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+            workflow_node_execution.workflow_run_id = workflow_run.id
+            workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+            workflow_node_execution.index = event.node_run_index
+            workflow_node_execution.node_execution_id = event.node_execution_id
+            workflow_node_execution.node_id = event.node_id
+            workflow_node_execution.node_type = event.node_type.value
+            workflow_node_execution.title = event.node_data.title
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
+            workflow_node_execution.created_by_role = workflow_run.created_by_role
+            workflow_node_execution.created_by = workflow_run.created_by
+            workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
+
+            session.add(workflow_node_execution)
+            session.commit()
+            session.refresh(workflow_node_execution)
 
         self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
 
         return workflow_node_execution
 
     def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
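The switch to `Session(db.engine, expire_on_commit=False)` matters because the returned object is read after the session is gone. A self-contained illustration (the table and engine here are hypothetical stand-ins):

```python
from sqlalchemy import Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class NodeExecution(Base):  # stand-in for WorkflowNodeExecution
    __tablename__ = "node_execution"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    title: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite://")  # hypothetical in-memory engine
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    record = NodeExecution(title="start")
    session.add(record)
    session.commit()

# Safe even though the session is closed: expire_on_commit=False left the
# loaded attributes on the instance instead of expiring them at commit.
print(record.title)
```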
@@ -76,8 +76,16 @@ def to_prompt_message_content(f: File, /):
 
 
 def download(f: File, /):
-    upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-    return _download_file_content(upload_file.key)
+    if f.transfer_method == FileTransferMethod.TOOL_FILE:
+        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
+        return _download_file_content(tool_file.file_key)
+    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
+        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
+        return _download_file_content(upload_file.key)
+    # remote file
+    response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
+    response.raise_for_status()
+    return response.content
 
 
 def _download_file_content(path: str, /):
@@ -1,8 +1,9 @@
 from typing import Optional
 
-from flask import Config, Flask
+from flask import Flask
 from pydantic import BaseModel
 
+from configs import dify_config
 from core.entities.provider_entities import QuotaUnit, RestrictModel
 from core.model_runtime.entities.model_entities import ModelType
 from models.provider import ProviderQuotaType
@@ -44,32 +45,30 @@ class HostingConfiguration:
     moderation_config: HostedModerationConfig = None
 
     def init_app(self, app: Flask) -> None:
-        config = app.config
-
-        if config.get("EDITION") != "CLOUD":
+        if dify_config.EDITION != "CLOUD":
             return
 
-        self.provider_map["azure_openai"] = self.init_azure_openai(config)
-        self.provider_map["openai"] = self.init_openai(config)
-        self.provider_map["anthropic"] = self.init_anthropic(config)
-        self.provider_map["minimax"] = self.init_minimax(config)
-        self.provider_map["spark"] = self.init_spark(config)
-        self.provider_map["zhipuai"] = self.init_zhipuai(config)
+        self.provider_map["azure_openai"] = self.init_azure_openai()
+        self.provider_map["openai"] = self.init_openai()
+        self.provider_map["anthropic"] = self.init_anthropic()
+        self.provider_map["minimax"] = self.init_minimax()
+        self.provider_map["spark"] = self.init_spark()
+        self.provider_map["zhipuai"] = self.init_zhipuai()
 
-        self.moderation_config = self.init_moderation_config(config)
+        self.moderation_config = self.init_moderation_config()
 
     @staticmethod
-    def init_azure_openai(app_config: Config) -> HostingProvider:
+    def init_azure_openai() -> HostingProvider:
         quota_unit = QuotaUnit.TIMES
-        if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
+        if dify_config.HOSTED_AZURE_OPENAI_ENABLED:
             credentials = {
-                "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
-                "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
+                "openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY,
+                "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE,
                 "base_model_name": "gpt-35-turbo",
             }
 
             quotas = []
-            hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
+            hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
             trial_quota = TrialHostingQuota(
                 quota_limit=hosted_quota_limit,
                 restrict_models=[
@@ -122,31 +121,31 @@ class HostingConfiguration:
             quota_unit=quota_unit,
         )
 
-    def init_openai(self, app_config: Config) -> HostingProvider:
+    def init_openai(self) -> HostingProvider:
         quota_unit = QuotaUnit.CREDITS
         quotas = []
 
-        if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
-            hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
-            trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
+        if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
+            hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
+            trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
             trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
             quotas.append(trial_quota)
 
-        if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
-            paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
+        if dify_config.HOSTED_OPENAI_PAID_ENABLED:
+            paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS")
             paid_quota = PaidHostingQuota(restrict_models=paid_models)
             quotas.append(paid_quota)
 
         if len(quotas) > 0:
             credentials = {
-                "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
+                "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY,
             }
 
-            if app_config.get("HOSTED_OPENAI_API_BASE"):
-                credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")
+            if dify_config.HOSTED_OPENAI_API_BASE:
+                credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE
 
-            if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
-                credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")
+            if dify_config.HOSTED_OPENAI_API_ORGANIZATION:
+                credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION
 
             return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
@@ -156,26 +155,26 @@ class HostingConfiguration:
         )
 
     @staticmethod
-    def init_anthropic(app_config: Config) -> HostingProvider:
+    def init_anthropic() -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
         quotas = []
 
-        if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
-            hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
+        if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
+            hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
             trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
             quotas.append(trial_quota)
 
-        if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
+        if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
            paid_quota = PaidHostingQuota()
             quotas.append(paid_quota)
 
         if len(quotas) > 0:
             credentials = {
-                "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
+                "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY,
             }
 
-            if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
-                credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")
+            if dify_config.HOSTED_ANTHROPIC_API_BASE:
+                credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE
 
             return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
@@ -185,9 +184,9 @@ class HostingConfiguration:
         )
 
     @staticmethod
-    def init_minimax(app_config: Config) -> HostingProvider:
+    def init_minimax() -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
-        if app_config.get("HOSTED_MINIMAX_ENABLED"):
+        if dify_config.HOSTED_MINIMAX_ENABLED:
             quotas = [FreeHostingQuota()]
 
             return HostingProvider(
@@ -203,9 +202,9 @@ class HostingConfiguration:
         )
 
     @staticmethod
-    def init_spark(app_config: Config) -> HostingProvider:
+    def init_spark() -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
-        if app_config.get("HOSTED_SPARK_ENABLED"):
+        if dify_config.HOSTED_SPARK_ENABLED:
             quotas = [FreeHostingQuota()]
 
             return HostingProvider(
@@ -221,9 +220,9 @@ class HostingConfiguration:
         )
 
     @staticmethod
-    def init_zhipuai(app_config: Config) -> HostingProvider:
+    def init_zhipuai() -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
-        if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
+        if dify_config.HOSTED_ZHIPUAI_ENABLED:
             quotas = [FreeHostingQuota()]
 
             return HostingProvider(
@@ -239,17 +238,15 @@ class HostingConfiguration:
         )
 
     @staticmethod
-    def init_moderation_config(app_config: Config) -> HostedModerationConfig:
-        if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"):
-            return HostedModerationConfig(
-                enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",")
-            )
+    def init_moderation_config() -> HostedModerationConfig:
+        if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
+            return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))
 
         return HostedModerationConfig(enabled=False)
 
     @staticmethod
-    def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
-        models_str = app_config.get(env_var)
+    def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]:
+        models_str = dify_config.model_dump().get(env_var)
         models_list = models_str.split(",") if models_str else []
         return [
             RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)
@@ -8,6 +8,8 @@ from core.llm_generator.output_parser.suggested_questions_after_answer import Su
 from core.llm_generator.prompts import (
     CONVERSATION_TITLE_PROMPT,
     GENERATOR_QA_PROMPT,
+    JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
+    PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
@@ -239,6 +241,54 @@ class LLMGenerator:
 
         return rule_config
 
+    @classmethod
+    def generate_code(
+        cls,
+        tenant_id: str,
+        instruction: str,
+        model_config: dict,
+        code_language: str = "javascript",
+        max_tokens: int = 1000,
+    ) -> dict:
+        if code_language == "python":
+            prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
+        else:
+            prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
+
+        prompt = prompt_template.format(
+            inputs={
+                "INSTRUCTION": instruction,
+                "CODE_LANGUAGE": code_language,
+            },
+            remove_template_variables=False,
+        )
+
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.get("provider") if model_config else None,
+            model=model_config.get("name") if model_config else None,
+        )
+
+        prompt_messages = [UserPromptMessage(content=prompt)]
+        model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
+
+        try:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+            )
+
+            generated_code = response.message.content
+            return {"code": generated_code, "language": code_language, "error": ""}
+
+        except InvokeError as e:
+            error = str(e)
+            return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
+        except Exception as e:
+            logging.exception(e)
+            return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
+
     @classmethod
     def generate_qa_document(cls, tenant_id: str, query, document_language: str):
         prompt = GENERATOR_QA_PROMPT.format(language=document_language)
@@ -61,6 +61,73 @@ User Input: yo, 你今天咋样?
User Input:
""" # noqa: E501

PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = (
    "You are an expert programmer. Generate code based on the following instructions:\n\n"
    "Instructions: {{INSTRUCTION}}\n\n"
    "Write the code in {{CODE_LANGUAGE}}.\n\n"
    "Please ensure that you meet the following requirements:\n"
    "1. Define a function named 'main'.\n"
    "2. The 'main' function must return a dictionary (dict).\n"
    "3. You may modify the arguments of the 'main' function, but include appropriate type hints.\n"
    "4. The returned dictionary should contain at least one key-value pair.\n\n"
    "5. You may ONLY use the following libraries in your code: \n"
    "- json\n"
    "- datetime\n"
    "- math\n"
    "- random\n"
    "- re\n"
    "- string\n"
    "- sys\n"
    "- time\n"
    "- traceback\n"
    "- uuid\n"
    "- os\n"
    "- base64\n"
    "- hashlib\n"
    "- hmac\n"
    "- binascii\n"
    "- collections\n"
    "- functools\n"
    "- operator\n"
    "- itertools\n\n"
    "Example:\n"
    "def main(arg1: str, arg2: int) -> dict:\n"
    "    return {\n"
    '        "result": arg1 * arg2,\n'
    "    }\n\n"
    "IMPORTANT:\n"
    "- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
    "- DO NOT use markdown code blocks (``` or ``` python). Return the raw code directly.\n"
    "- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
    "- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
    "- Always use the format return {'result': ...} for the output.\n\n"
    "Generated Code:\n"
)
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
    "You are an expert programmer. Generate code based on the following instructions:\n\n"
    "Instructions: {{INSTRUCTION}}\n\n"
    "Write the code in {{CODE_LANGUAGE}}.\n\n"
    "Please ensure that you meet the following requirements:\n"
    "1. Define a function named 'main'.\n"
    "2. The 'main' function must return an object.\n"
    "3. You may modify the arguments of the 'main' function, but include appropriate JSDoc annotations.\n"
    "4. The returned object should contain at least one key-value pair.\n\n"
    "5. The returned object should always be in the format: {result: ...}\n\n"
    "Example:\n"
    "function main(arg1, arg2) {\n"
    "    return {\n"
    "        result: arg1 * arg2\n"
    "    };\n"
    "}\n\n"
    "IMPORTANT:\n"
    "- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
    "- DO NOT use markdown code blocks (``` or ``` javascript). Return the raw code directly.\n"
    "- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
    "- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
    "Generated Code:\n"
)
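For reference, both code-generator templates above are plain strings with double-brace placeholders. A minimal rendering sketch (the helper name is illustrative, not the repository's actual function):

```python
# Sketch: fill the {{INSTRUCTION}} and {{CODE_LANGUAGE}} placeholders before
# sending the prompt to the model. render_code_prompt is hypothetical.
def render_code_prompt(template: str, instruction: str, code_language: str) -> str:
    return template.replace("{{INSTRUCTION}}", instruction).replace("{{CODE_LANGUAGE}}", code_language)

prompt = render_code_prompt(
    PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
    instruction="Multiply two integers and return the product",
    code_language="Python 3",
)
```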


SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
    "Please help me predict the three most likely questions that human would ask, "
    "and keeping each question under 20 characters.\n"

@@ -2,6 +2,7 @@ from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
+from core.file.models import FileType
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,
@@ -98,8 +99,9 @@ class TokenBufferMemory:
                 prompt_message_contents: list[PromptMessageContent] = []
                 prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file_obj in file_objs:
-                    prompt_message = file_manager.to_prompt_message_content(file_obj)
-                    prompt_message_contents.append(prompt_message)
+                    if file_obj.type in {FileType.IMAGE, FileType.AUDIO}:
+                        prompt_message = file_manager.to_prompt_message_content(file_obj)
+                        prompt_message_contents.append(prompt_message)

                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
             else:
@@ -218,7 +218,7 @@ For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` param
 However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:

 ```python
-def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
     """
     used to define customizable model schema
     """
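Concretely, the dynamic branch described above can look roughly like the sketch below. This is an illustration, not repository code: the `ParameterRule` and `ParameterType` field names are assumptions about the runtime's entity classes, and "A" stands in for the hypothetical top_k-capable model.

```python
# Sketch only: emit a top_k parameter rule just for the models that support it.
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
    parameter_rules = [
        ParameterRule(
            name="temperature",
            type=ParameterType.FLOAT,
            use_template="temperature",
            label=I18nObject(en_US="Temperature"),
        )
    ]
    if model == "A":  # model A supports top_k; model B does not
        parameter_rules.append(
            ParameterRule(name="top_k", type=ParameterType.INT, required=False, label=I18nObject(en_US="Top k"))
        )
    return AIModelEntity(
        model=model,
        label=I18nObject(en_US=model),
        model_type=ModelType.LLM,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties={ModelPropertyKey.MODE: "chat"},
        parameter_rules=parameter_rules,
    )
```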
@@ -205,7 +205,7 @@ provider_credential_schema:
 However, some vendors support different parameters for different models. For example, the vendor `OpenLLM` supports `top_k`, but not every model offered by this vendor supports `top_k`. Say model A supports `top_k` while model B does not; in that case the model parameter schema needs to be generated dynamically, as shown below:

 ```python
-def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
     """
     used to define customizable model schema
     """
@@ -105,6 +105,7 @@ class LLMResult(BaseModel):
     Model class for llm result.
     """

+    id: Optional[str] = None
     model: str
     prompt_messages: list[PromptMessage]
     message: AssistantPromptMessage
@@ -1,3 +1,4 @@
+- claude-3-5-sonnet-20241022
 - claude-3-5-sonnet-20240620
 - claude-3-haiku-20240307
 - claude-3-opus-20240229
@@ -0,0 +1,39 @@
model: claude-3-5-sonnet-20241022
label:
  en_US: claude-3-5-sonnet-20241022
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
  - name: response_format
    use_template: response_format
pricing:
  input: '3.00'
  output: '15.00'
  unit: '0.000001'
  currency: USD
@@ -294,7 +294,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
             ],
         }

-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         Used to define customizable model schema
         """
@@ -148,7 +148,7 @@ class AzureRerankModel(RerankModel):
             InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
         }

-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
@@ -53,6 +53,9 @@ model_credential_schema:
       type: select
       required: true
       options:
+        - label:
+            en_US: 2024-10-01-preview
+          value: 2024-10-01-preview
         - label:
             en_US: 2024-09-01-preview
           value: 2024-09-01-preview
@@ -45,9 +45,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise ValueError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
@@ -81,9 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         prompt_messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise ValueError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         if not model_entity:
             raise ValueError(f"Base Model Name {base_model_name} is invalid")
@@ -108,9 +104,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if "base_model_name" not in credentials:
             raise CredentialsValidateFailedError("Base Model Name is required")

-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise CredentialsValidateFailedError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if not ai_model_entity:
@@ -149,9 +143,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))

     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        base_model_name = credentials.get("base_model_name")
-        if not base_model_name:
-            raise ValueError("Base Model Name is required")
+        base_model_name = self._get_base_model_name(credentials)
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None

@@ -308,11 +300,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         if tools:
             extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            # extra_model_kwargs['functions'] = [{
-            #     "name": tool.name,
-            #     "description": tool.description,
-            #     "parameters": tool.parameters
-            # } for tool in tools]

         if stop:
             extra_model_kwargs["stop"] = stop
@@ -769,3 +756,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         ai_model_entity_copy.entity.label.en_US = model
         ai_model_entity_copy.entity.label.zh_Hans = model
         return ai_model_entity_copy
+
+    def _get_base_model_name(self, credentials: dict) -> str:
+        base_model_name = credentials.get("base_model_name")
+        if not base_model_name:
+            raise ValueError("Base Model Name is required")
+        return base_model_name
@@ -0,0 +1,60 @@
model: anthropic.claude-3-5-sonnet-20241022-v2:0
label:
  en_US: Claude 3.5 Sonnet V2
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # tip docs from aws has error, max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
  - name: response_format
    use_template: response_format
pricing:
  input: '0.003'
  output: '0.015'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,60 @@
model: eu.anthropic.claude-3-5-sonnet-20241022-v2:0
label:
  en_US: Claude 3.5 Sonnet V2(EU.Cross Region Inference)
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # tip docs from aws has error, max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
  - name: response_format
    use_template: response_format
pricing:
  input: '0.003'
  output: '0.015'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,60 @@
model: us.anthropic.claude-3-5-sonnet-20241022-v2:0
label:
  en_US: Claude 3.5 Sonnet V2(US.Cross Region Inference)
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # tip docs from aws has error, max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
  - name: response_format
    use_template: response_format
pricing:
  input: '0.003'
  output: '0.015'
  unit: '0.001'
  currency: USD
File diff suppressed because one or more lines are too long (image file, 9.8 KiB).
@@ -0,0 +1,3 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M25.132 24.3947C25.497 25.7527 25.8984 27.1413 26.3334 28.5834C26.7302 29.8992 25.5459 30.4167 25.0752 29.1758C24.571 27.8466 24.0885 26.523 23.6347 25.1729C21.065 26.4654 18.5025 27.5424 15.5961 28.7541C16.7581 33.0256 17.8309 36.5984 19.4952 39.9935C19.4953 39.9936 19.4953 39.9937 19.4954 39.9938C19.6631 39.9979 19.8313 40 20 40C31.0457 40 40 31.0457 40 20C40 16.0335 38.8453 12.3366 36.8537 9.22729C31.6585 9.69534 27.0513 10.4562 22.8185 11.406C22.8882 12.252 22.9677 13.0739 23.0555 13.855C23.3824 16.7604 23.9112 19.5281 24.6137 22.3836C27.0581 21.2848 29.084 20.3225 30.6816 19.522C32.2154 18.7535 33.6943 18.7062 31.2018 20.6594C29.0388 22.1602 27.0644 23.3566 25.132 24.3947ZM36.1559 8.20846C33.0001 3.89184 28.1561 0.887462 22.5955 0.166882C22.4257 2.86234 22.4785 6.26344 22.681 9.50447C26.7473 8.88859 31.1721 8.46032 36.1559 8.20846ZM19.9369 9.73661e-05C19.7594 2.92694 19.8384 6.65663 20.19 9.91293C17.3748 10.4109 14.7225 11.0064 12.1592 11.7038C12.0486 10.4257 11.9927 9.25764 11.9927 8.24178C11.9927 7.5054 11.3957 6.90844 10.6593 6.90844C9.92296 6.90844 9.32601 7.5054 9.32601 8.24178C9.32601 9.47868 9.42873 10.898 9.61402 12.438C8.33567 12.8278 7.07397 13.2443 5.81918 13.688C5.12493 13.9336 4.76118 14.6954 5.0067 15.3896C5.25223 16.0839 6.01406 16.4476 6.7083 16.2021C7.7931 15.8185 8.88482 15.4388 9.98927 15.0659C10.5222 18.3344 11.3344 21.9428 12.2703 25.4156C12.4336 26.0218 12.6062 26.6262 12.7863 27.2263C9.34168 28.4135 5.82612 29.3782 2.61128 29.8879C0.949407 26.9716 0 23.5967 0 20C0 8.97534 8.92023 0.0341108 19.9369 9.73661e-05ZM4.19152 32.2527C7.45069 36.4516 12.3458 39.3173 17.9204 39.8932C16.5916 37.455 14.9338 33.717 13.5405 29.5901C10.4404 30.7762 7.25883 31.6027 4.19152 32.2527ZM22.9735 23.1135C22.1479 20.41 21.4462 17.5441 20.9225 14.277C20.746 13.5841 20.5918 12.8035 20.4593 11.9636C17.6508 12.6606 14.9992 13.4372 12.4356 14.2598C12.8479 17.4766 13.5448 21.1334 14.5118 24.7218C14.662 25.2792 14.8081 25.8248 14.9514 26.3594L14.9516 26.3603L14.9524 26.3634L14.9526 26.3639L14.973 26.4401C16.1833 25.9872 17.3746 25.5123 18.53 25.0259C20.1235 24.3552 21.6051 23.7165 22.9735 23.1135Z" fill="#141519"/>
</svg>
api/core/model_runtime/model_providers/gitee_ai/_common.py (Normal file, 47 lines)
@@ -0,0 +1,47 @@
from dashscope.common.error import (
    AuthenticationError,
    InvalidParameter,
    RequestFailure,
    ServiceUnavailableError,
    UnsupportedHTTPMethod,
    UnsupportedModel,
)

from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)


class _CommonGiteeAI:
    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                RequestFailure,
            ],
            InvokeServerUnavailableError: [
                ServiceUnavailableError,
            ],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [
                AuthenticationError,
            ],
            InvokeBadRequestError: [
                InvalidParameter,
                UnsupportedModel,
                UnsupportedHTTPMethod,
            ],
        }
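The mapping above exists so shared invoke plumbing can translate vendor exceptions into Dify's unified invoke errors. A minimal sketch of that translation step (illustrative only, not the runtime's actual dispatcher):

```python
# Sketch: re-raise a vendor exception as the unified InvokeError subclass it
# is registered under; fall back to the generic InvokeError otherwise.
def to_unified_error(mapping: dict[type[InvokeError], list[type[Exception]]], error: Exception) -> InvokeError:
    for unified_type, vendor_types in mapping.items():
        if any(isinstance(error, vendor_type) for vendor_type in vendor_types):
            return unified_type(str(error))
    return InvokeError(str(error))
```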
api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py (Normal file, 25 lines)
@@ -0,0 +1,25 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class GiteeAIProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)
            model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
            raise ex
@@ -0,0 +1,35 @@
provider: gitee_ai
label:
  en_US: Gitee AI
  zh_Hans: Gitee AI
description:
  en_US: Quickly experience large models and take the lead in exploring the open-source AI world.
  zh_Hans: 快速体验大模型,领先探索 AI 开源世界
icon_small:
  en_US: Gitee-AI-Logo.svg
icon_large:
  en_US: Gitee-AI-Logo-full.svg
help:
  title:
    en_US: Get your token from Gitee AI
    zh_Hans: 从 Gitee AI 获取 token
  url:
    en_US: https://ai.gitee.com/dashboard/settings/tokens
supported_model_types:
  - llm
  - text-embedding
  - rerank
  - speech2text
  - tts
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
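The single credential declared by this schema is the API key that the model classes below read via `credentials.get("api_key")`; a configured credentials dict is simply:

```python
# Illustrative credentials dict produced by the schema above; placeholder value.
credentials = {"api_key": "<your Gitee AI token>"}
```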
@@ -0,0 +1,105 @@
model: Qwen2-72B-Instruct
label:
  zh_Hans: Qwen2-72B-Instruct
  en_US: Qwen2-72B-Instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 6400
parameter_rules:
  - name: stream
    use_template: boolean
    label:
      en_US: "Stream"
      zh_Hans: "流式"
    type: boolean
    default: true
    required: true
    help:
      en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
      zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"

  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,105 @@
model: Qwen2-7B-Instruct
label:
  zh_Hans: Qwen2-7B-Instruct
  en_US: Qwen2-7B-Instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: stream
    use_template: boolean
    label:
      en_US: "Stream"
      zh_Hans: "流式"
    type: boolean
    default: true
    required: true
    help:
      en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
      zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"

  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,105 @@
model: Yi-1.5-34B-Chat
label:
  zh_Hans: Yi-1.5-34B-Chat
  en_US: Yi-1.5-34B-Chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: stream
    use_template: boolean
    label:
      en_US: "Stream"
      zh_Hans: "流式"
    type: boolean
    default: true
    required: true
    help:
      en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
      zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"

  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,7 @@
- Qwen2-7B-Instruct
- Qwen2-72B-Instruct
- Yi-1.5-34B-Chat
- glm-4-9b-chat
- deepseek-coder-33B-instruct-chat
- deepseek-coder-33B-instruct-completions
- codegeex4-all-9b
@@ -0,0 +1,105 @@
model: codegeex4-all-9b
label:
  zh_Hans: codegeex4-all-9b
  en_US: codegeex4-all-9b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 40960
parameter_rules:
  - name: stream
    use_template: boolean
    label:
      en_US: "Stream"
      zh_Hans: "流式"
    type: boolean
    default: true
    required: true
    help:
      en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
      zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"

  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,105 @@
model: deepseek-coder-33B-instruct-chat
label:
  zh_Hans: deepseek-coder-33B-instruct-chat
  en_US: deepseek-coder-33B-instruct-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 9000
parameter_rules:
  - name: stream
    use_template: boolean
    label:
      en_US: "Stream"
      zh_Hans: "流式"
    type: boolean
    default: true
    required: true
    help:
      en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
      zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"

  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,91 @@
model: deepseek-coder-33B-instruct-completions
label:
  zh_Hans: deepseek-coder-33B-instruct-completions
  en_US: deepseek-coder-33B-instruct-completions
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 9000
parameter_rules:
  - name: stream
    use_template: boolean
    label:
      en_US: "Stream"
      zh_Hans: "流式"
    type: boolean
    default: true
    required: true
    help:
      en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
      zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"

  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,105 @@
model: glm-4-9b-chat
label:
  zh_Hans: glm-4-9b-chat
  en_US: glm-4-9b-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: stream
    use_template: boolean
    label:
      en_US: "Stream"
      zh_Hans: "流式"
    type: boolean
    default: true
    required: true
    help:
      en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
      zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"

  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
api/core/model_runtime/model_providers/gitee_ai/llm/llm.py (Normal file, 47 lines)
@@ -0,0 +1,47 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    MODEL_TO_IDENTITY: dict[str, str] = {
        "Yi-1.5-34B-Chat": "Yi-34B-Chat",
        "deepseek-coder-33B-instruct-completions": "deepseek-coder-33B-instruct",
        "deepseek-coder-33B-instruct-chat": "deepseek-coder-33B-instruct",
    }

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials, model, model_parameters)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials, model, None)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict, model: str, model_parameters: dict) -> None:
        if model is None:
            model = "bge-large-zh-v1.5"

        model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
        credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
        if model.endswith("completions"):
            credentials["mode"] = LLMMode.COMPLETION.value
        else:
            credentials["mode"] = LLMMode.CHAT.value
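A usage note on `_add_custom_parameters`: it mutates the credentials dict in place, and its effect follows directly from the identity mapping and suffix check above (the token value is a placeholder):

```python
# Sketch: credentials after _add_custom_parameters for a remapped
# completions-mode model.
credentials = {"api_key": "<token>"}
GiteeAILargeLanguageModel._add_custom_parameters(credentials, "deepseek-coder-33B-instruct-completions", {})
assert credentials["endpoint_url"] == "https://ai.gitee.com/api/serverless/deepseek-coder-33B-instruct/"
assert credentials["mode"] == "completion"  # LLMMode.COMPLETION.value
```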
@@ -0,0 +1 @@
- bge-reranker-v2-m3
@@ -0,0 +1,4 @@
model: bge-reranker-v2-m3
model_type: rerank
model_properties:
  context_size: 1024
api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py (Normal file, 128 lines)
@ -0,0 +1,128 @@
from typing import Optional

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class GiteeAIRerankModel(RerankModel):
    """
    Model class for rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n documents to return
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        base_url = credentials.get("base_url", "https://ai.gitee.com/api/serverless")
        base_url = base_url.removesuffix("/")

        try:
            body = {"model": model, "query": query, "documents": docs}
            if top_n is not None:
                body["top_n"] = top_n
            response = httpx.post(
                f"{base_url}/{model}/rerank",
                json=body,
                headers={"Authorization": f"Bearer {credentials.get('api_key')}"},
            )

            response.raise_for_status()
            results = response.json()

            rerank_documents = []
            for result in results["results"]:
                rerank_document = RerankDocument(
                    index=result["index"],
                    text=result["document"]["text"],
                    score=result["relevance_score"],
                )
                if score_threshold is None or result["relevance_score"] >= score_threshold:
                    rerank_documents.append(rerank_document)
            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.01,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.RERANK,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
        )

        return entity
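As a quick illustration of the request and response shape this class relies on, here is a minimal standalone sketch of calling the same serverless rerank endpoint directly; the API key and documents are placeholders, not values taken from this diff:

```python
import httpx

API_KEY = "your-gitee-ai-api-key"  # placeholder credential
MODEL = "bge-reranker-v2-m3"       # the model listed in _position.yaml above

body = {
    "model": MODEL,
    "query": "What is the capital of the United States?",
    "documents": [
        "Carson City is the capital city of the American state of Nevada.",
        "Saipan is the capital of the Northern Mariana Islands.",
    ],
    "top_n": 2,  # optional, mirrors the top_n handling in _invoke above
}
response = httpx.post(
    f"https://ai.gitee.com/api/serverless/{MODEL}/rerank",
    json=body,
    headers={"Authorization": f"Bearer {API_KEY}"},
)
response.raise_for_status()
# Each result carries the original index, document text, and relevance score.
for result in response.json()["results"]:
    print(result["index"], round(result["relevance_score"], 3))
```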
@@ -0,0 +1,2 @@
- whisper-base
- whisper-large
@@ -0,0 +1,53 @@
import os
from typing import IO, Optional

import requests

from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI


class GiteeAISpeech2TextModel(_CommonGiteeAI, Speech2TextModel):
    """
    Model class for OpenAI Compatible Speech to text model.
    """

    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/speech-to-text

        endpoint_url = f"https://ai.gitee.com/api/serverless/{model}/speech-to-text"
        files = [("file", file)]
        _, file_ext = os.path.splitext(file.name)
        headers = {"Content-Type": f"audio/{file_ext}", "Authorization": f"Bearer {credentials.get('api_key')}"}
        response = requests.post(endpoint_url, headers=headers, files=files)
        if response.status_code != 200:
            raise InvokeBadRequestError(response.text)
        response_data = response.json()
        return response_data["text"]

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            audio_file_path = self._get_demo_file_path()

            with open(audio_file_path, "rb") as audio_file:
                self._invoke(model, credentials, audio_file)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
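For reference, a minimal sketch of hitting the same speech-to-text endpoint outside the model runtime; the file path and API key are placeholders. Note that `os.path.splitext` keeps the leading dot, so it is stripped here before building the MIME type:

```python
import os
import requests

API_KEY = "your-gitee-ai-api-key"  # placeholder credential
MODEL = "whisper-base"             # or whisper-large, per _position.yaml

with open("sample.mp3", "rb") as audio_file:
    _, file_ext = os.path.splitext(audio_file.name)
    response = requests.post(
        f"https://ai.gitee.com/api/serverless/{MODEL}/speech-to-text",
        headers={
            "Content-Type": f"audio/{file_ext.lstrip('.')}",
            "Authorization": f"Bearer {API_KEY}",
        },
        files=[("file", audio_file)],
    )
response.raise_for_status()
print(response.json()["text"])  # transcribed text
```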
@@ -0,0 +1,5 @@
model: whisper-base
model_type: speech2text
model_properties:
  file_upload_limit: 1
  supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
@@ -0,0 +1,5 @@
model: whisper-large
model_type: speech2text
model_properties:
  file_upload_limit: 1
  supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
@@ -0,0 +1,3 @@
- bge-large-zh-v1.5
- bge-small-zh-v1.5
- bge-m3
@@ -0,0 +1,8 @@
model: bge-large-zh-v1.5
label:
  zh_Hans: bge-large-zh-v1.5
  en_US: bge-large-zh-v1.5
model_type: text-embedding
model_properties:
  context_size: 200000
  max_chunks: 20
@@ -0,0 +1,8 @@
model: bge-m3
label:
  zh_Hans: bge-m3
  en_US: bge-m3
model_type: text-embedding
model_properties:
  context_size: 200000
  max_chunks: 20
@@ -0,0 +1,8 @@
model: bge-small-zh-v1.5
label:
  zh_Hans: bge-small-zh-v1.5
  en_US: bge-small-zh-v1.5
model_type: text-embedding
model_properties:
  context_size: 200000
  max_chunks: 20
@@ -0,0 +1,31 @@
from typing import Optional

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
    OAICompatEmbeddingModel,
)


class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        self._add_custom_parameters(credentials, model)
        return super()._invoke(model, credentials, texts, user, input_type)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials, None)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
        # model is Optional: validate_credentials passes None on purpose.
        if model is None:
            model = "bge-m3"

        credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/"
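The embedding class above only overrides credential preparation: every model is served from its own OpenAI-compatible base URL. A small sketch of the resulting endpoint construction (the helper name is illustrative, not part of the diff):

```python
from typing import Optional

def gitee_ai_embedding_endpoint(model: Optional[str]) -> str:
    # Mirrors GiteeAIEmbeddingModel._add_custom_parameters: when no model is
    # given (as during credential validation), bge-m3 is used as the default.
    return f"https://ai.gitee.com/api/serverless/{model or 'bge-m3'}/v1/"

assert gitee_ai_embedding_endpoint(None) == "https://ai.gitee.com/api/serverless/bge-m3/v1/"
assert gitee_ai_embedding_endpoint("bge-large-zh-v1.5").endswith("/bge-large-zh-v1.5/v1/")
```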
@@ -0,0 +1,11 @@
model: ChatTTS
model_type: tts
model_properties:
  default_voice: 'default'
  voices:
    - mode: 'default'
      name: 'Default'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
@@ -0,0 +1,11 @@
model: FunAudioLLM-CosyVoice-300M
model_type: tts
model_properties:
  default_voice: 'default'
  voices:
    - mode: 'default'
      name: 'Default'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
@@ -0,0 +1,4 @@
- speecht5_tts
- ChatTTS
- fish-speech-1.2-sft
- FunAudioLLM-CosyVoice-300M
@@ -0,0 +1,11 @@
model: fish-speech-1.2-sft
model_type: tts
model_properties:
  default_voice: 'default'
  voices:
    - mode: 'default'
      name: 'Default'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
@@ -0,0 +1,11 @@
model: speecht5_tts
model_type: tts
model_properties:
  default_voice: 'default'
  voices:
    - mode: 'default'
      name: 'Default'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
79 api/core/model_runtime/model_providers/gitee_ai/tts/tts.py Normal file
@@ -0,0 +1,79 @@
from typing import Any, Optional

import requests

from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI


class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
    """
    Model class for Gitee AI Text to Speech model.
    """

    def _invoke(
        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
    ) -> Any:
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :param user: unique user id
        :return: audio stream for the given text
        """
        return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        validate credentials for text2speech model

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._tts_invoke_streaming(
                model=model,
                credentials=credentials,
                content_text="Hello Dify!",
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: audio stream for the given text
        """
        try:
            # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/text-to-speech
            endpoint_url = "https://ai.gitee.com/api/serverless/" + model + "/text-to-speech"

            headers = {"Content-Type": "application/json"}
            api_key = credentials.get("api_key")
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            payload = {"inputs": content_text}
            response = requests.post(endpoint_url, headers=headers, json=payload)

            if response.status_code != 200:
                raise InvokeBadRequestError(response.text)

            data = response.content

            for i in range(0, len(data), 1024):
                yield data[i : i + 1024]
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
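`_tts_invoke_streaming` yields the synthesized MP3 in 1024-byte slices, so a caller drains the generator into a sink. A standalone sketch against the same endpoint (the API key and output path are placeholders):

```python
import requests

API_KEY = "your-gitee-ai-api-key"  # placeholder credential

response = requests.post(
    "https://ai.gitee.com/api/serverless/ChatTTS/text-to-speech",
    headers={"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"},
    json={"inputs": "Hello Dify!"},
)
response.raise_for_status()

# Re-chunk the audio the same way _tts_invoke_streaming does, then drain
# the generator into a file.
chunks = (response.content[i : i + 1024] for i in range(0, len(response.content), 1024))
with open("speech.mp3", "wb") as f:
    for chunk in chunks:
        f.write(chunk)
```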
@@ -118,7 +118,7 @@ class HuggingfaceTeiRerankModel(RerankModel):
             InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
         }
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
 
@@ -189,7 +189,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
 
         return usage
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
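This and the similar hunks below are type-annotation changes only: `Optional[AIModelEntity]` and `AIModelEntity | None` name the same type, but the `X | None` union syntax (PEP 604) requires Python 3.10+, while `typing.Optional` also parses on older interpreters. A one-line check, valid on 3.10+:

```python
from typing import Optional

# PEP 604 unions compare equal to their typing.Optional spelling (Python 3.10+).
assert Optional[int] == (int | None)
```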
@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from typing import cast
+from typing import Optional, cast
 
 from httpx import Timeout
 from openai import (
@@ -212,7 +212,7 @@ class LocalAILanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}")
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         completion_model = None
         if credentials["completion_type"] == "chat_completion":
             completion_model = LLMMode.CHAT.value
 
@@ -73,7 +73,7 @@ class LocalAISpeech2text(Speech2TextModel):
             InvokeBadRequestError: [InvokeBadRequestError],
         }
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
 
@@ -115,7 +115,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
             num_tokens += self._get_num_tokens_by_gpt2(text)
         return num_tokens
 
-    def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         Get customizable model schema
 
 
@@ -44,13 +44,16 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         self._add_custom_parameters(credentials)
         self._add_function_call(model, credentials)
         user = user[:32] if user else None
+        # {"response_format": "json_object"} needs to be converted to {"response_format": {"type": "json_object"}}
+        if "response_format" in model_parameters:
+            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         self._add_custom_parameters(credentials)
         super().validate_credentials(model, credentials)
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         return AIModelEntity(
             model=model,
             label=I18nObject(en_US=model, zh_Hans=model),
 
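The inserted lines normalize the flat `response_format` string Dify passes into the nested object shape the Moonshot API expects. The transformation in isolation:

```python
model_parameters = {"temperature": 0.3, "response_format": "json_object"}

# Same conversion as in MoonshotLargeLanguageModel._invoke above:
# {"response_format": "json_object"} -> {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
    model_parameters["response_format"] = {"type": model_parameters.get("response_format")}

assert model_parameters["response_format"] == {"type": "json_object"}
```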
@@ -61,7 +61,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel):
 
         return response.text
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
 
@@ -397,16 +397,21 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         chunk_index = 0
 
         def create_final_llm_result_chunk(
-            index: int, message: AssistantPromptMessage, finish_reason: str
+            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
         ) -> LLMResultChunk:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+            prompt_tokens = usage and usage.get("prompt_tokens")
+            if prompt_tokens is None:
+                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = usage and usage.get("completion_tokens")
+            if completion_tokens is None:
+                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
 
             # transform usage
             usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
             return LLMResultChunk(
+                id=id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
@@ -450,7 +455,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     tool_call.function.arguments += new_tool_call.function.arguments
 
         finish_reason = None  # The default value of finish_reason is None
-
+        message_id, usage = None, None
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
@@ -462,20 +467,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     continue
 
                 try:
-                    chunk_json = json.loads(decoded_chunk)
+                    chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
                     yield create_final_llm_result_chunk(
+                        id=message_id,
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
+                        usage=usage,
                     )
                     break
+                if chunk_json:
+                    if u := chunk_json.get("usage"):
+                        usage = u
                 if not chunk_json or len(chunk_json["choices"]) == 0:
                     continue
 
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
+                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
@@ -524,6 +535,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                         continue
 
                     yield LLMResultChunk(
+                        id=message_id,
                         model=model,
                         prompt_messages=prompt_messages,
                         delta=LLMResultChunkDelta(
@@ -536,6 +548,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         if tools_calls:
             yield LLMResultChunk(
+                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -545,17 +558,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
             )
 
         yield create_final_llm_result_chunk(
-            index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
+            id=message_id,
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason,
+            usage=usage,
         )
 
     def _handle_generate_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> LLMResult:
-        response_json = response.json()
+        response_json: dict = response.json()
 
         completion_type = LLMMode.value_of(credentials["mode"])
 
         output = response_json["choices"][0]
+        message_id = response_json.get("id")
 
         response_content = ""
         tool_calls = None
@@ -593,6 +611,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         # transform response
         result = LLMResult(
+            id=message_id,
             model=response_json["model"],
             prompt_messages=prompt_messages,
             message=assistant_message,
 
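Taken together, these hunks thread the server-assigned message id into every emitted chunk and prefer the usage block reported in the stream over re-tokenizing locally. The fallback logic in isolation (the helper name is hypothetical, not part of the diff):

```python
from typing import Callable, Optional

def resolve_token_counts(
    usage: Optional[dict],
    count_prompt: Callable[[], int],
    count_completion: Callable[[], int],
) -> tuple[int, int]:
    # Prefer server-reported usage; fall back to local token counting,
    # mirroring create_final_llm_result_chunk above.
    prompt_tokens = usage.get("prompt_tokens") if usage else None
    if prompt_tokens is None:
        prompt_tokens = count_prompt()
    completion_tokens = usage.get("completion_tokens") if usage else None
    if completion_tokens is None:
        completion_tokens = count_completion()
    return prompt_tokens, completion_tokens

# The server reported usage, so the local counters are never invoked.
assert resolve_token_counts({"prompt_tokens": 12, "completion_tokens": 34},
                            lambda: 0, lambda: 0) == (12, 34)
```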
@@ -62,7 +62,7 @@ class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
 
@@ -1,4 +1,5 @@
 from collections.abc import Generator
+from typing import Optional
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -193,7 +194,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
             ),
         )
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
 
@@ -408,7 +408,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
         }
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         used to define customizable model schema
         """
 
Some files were not shown because too many files have changed in this diff.