mirror of
https://github.com/langgenius/dify.git
synced 2026-02-14 23:35:31 +08:00
Compare commits
74 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 400392230b | |||
| eca66f9577 | |||
| 121bb99cc2 | |||
| cac1ef7ade | |||
| d74d79b3d8 | |||
| c6b28bc193 | |||
| 5d05574518 | |||
| bf478aeba2 | |||
| c9dfe1ad92 | |||
| 926609eb59 | |||
| e32116b9a3 | |||
| e11d5ac708 | |||
| f6c3d4cadc | |||
| 3e9d271b52 | |||
| ecc8beef3f | |||
| b9afb7bcec | |||
| b4041759f7 | |||
| c3473b5b4f | |||
| 1b9bf9c62d | |||
| ed96a6b6c0 | |||
| 4989d0c904 | |||
| 9a5bdae07f | |||
| 67016feb96 | |||
| 22bdfb7e56 | |||
| ceb2c4f3ef | |||
| d5a93a6400 | |||
| 01a2513812 | |||
| 8e7a752b2a | |||
| 999d3f1539 | |||
| a7ee51e5d8 | |||
| 0e965b6529 | |||
| a9db06f5e7 | |||
| 6827c4038b | |||
| e8a6e90a61 | |||
| ff956cb546 | |||
| 7d7e0f9800 | |||
| 3ae05a672d | |||
| d700abff0a | |||
| 5267f34e76 | |||
| d6e8290a1c | |||
| 36f66d40e5 | |||
| 5f12616cb9 | |||
| bc43efba75 | |||
| ef5f476cd6 | |||
| 98bf7710e4 | |||
| 7263af13ed | |||
| d992a809f5 | |||
| 04f8d39860 | |||
| b7bf14ab72 | |||
| e8abbe0623 | |||
| b14d59e977 | |||
| 5f12c17355 | |||
| d170d78530 | |||
| 4d9160ca9f | |||
| 8f670f31b8 | |||
| 5838345f48 | |||
| 3f1c84f65a | |||
| 83b2b8fe60 | |||
| ac24300274 | |||
| 2e657b7b12 | |||
| c063617553 | |||
| 38a4f0234d | |||
| 740a723072 | |||
| 495cf58014 | |||
| 8e98759359 | |||
| 4ae0bb83f1 | |||
| 5459d812e7 | |||
| 831c222541 | |||
| faad247d85 | |||
| 1e829ceaf3 | |||
| 79fe175440 | |||
| 9b32bfb3db | |||
| 37fea072bc | |||
| 31a603e905 |
@ -168,7 +168,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
> Before installing Dify, make sure your machine meets the following minimum system requirements:
|
||||
>
|
||||
>- CPU >= 2 Core
|
||||
>- RAM >= 4GB
|
||||
>- RAM >= 4 GiB
|
||||
|
||||
</br>
|
||||
|
||||
|
||||
@ -174,7 +174,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
|
||||
在安装 Dify 之前,请确保您的机器满足以下最低系统要求:
|
||||
|
||||
- CPU >= 2 Core
|
||||
- RAM >= 4GB
|
||||
- RAM >= 4 GiB
|
||||
|
||||
### 快速启动
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ SUPABASE_URL=your-server-url
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb, upstash
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@ -220,6 +220,10 @@ BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
UPSTASH_VECTOR_TOKEN=your-access-token
|
||||
|
||||
# ViKingDB configuration
|
||||
VIKINGDB_ACCESS_KEY=your-ak
|
||||
VIKINGDB_SECRET_KEY=your-sk
|
||||
@ -239,6 +243,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
# Model Configuration
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
|
||||
# Mail configuration, support: resend, smtp
|
||||
MAIL_TYPE=
|
||||
@ -304,6 +309,10 @@ RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
# Log file path
|
||||
LOG_FILE=
|
||||
# Log file max size, the unit is MB
|
||||
LOG_FILE_MAX_SIZE=20
|
||||
# Log file max backup count
|
||||
LOG_FILE_BACKUP_COUNT=5
|
||||
|
||||
# Indexing configuration
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
|
||||
|
||||
16
api/app.py
16
api/app.py
@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||
from gevent import monkey
|
||||
|
||||
@ -36,17 +38,11 @@ if hasattr(time, "tzset"):
|
||||
time.tzset()
|
||||
|
||||
|
||||
# -------------
|
||||
# Configuration
|
||||
# -------------
|
||||
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
|
||||
|
||||
|
||||
# create app
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
if app.config.get("TESTING"):
|
||||
if dify_config.TESTING:
|
||||
print("App is running in TESTING mode")
|
||||
|
||||
|
||||
@ -54,15 +50,15 @@ if app.config.get("TESTING"):
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.set_cookie("remember_token", "", expires=0)
|
||||
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
|
||||
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
|
||||
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
|
||||
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/health")
|
||||
def health():
|
||||
return Response(
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
|
||||
status=200,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
@ -10,9 +10,6 @@ if os.environ.get("DEBUG", "false").lower() != "true":
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
from flask import Flask, Response, request
|
||||
from flask_cors import CORS
|
||||
@ -27,6 +24,7 @@ from extensions import (
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_hosting_provider,
|
||||
ext_logging,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
@ -70,43 +68,7 @@ def create_flask_app_with_configs() -> Flask:
|
||||
|
||||
def create_app() -> Flask:
|
||||
app = create_flask_app_with_configs()
|
||||
|
||||
app.secret_key = app.config["SECRET_KEY"]
|
||||
|
||||
log_handlers = None
|
||||
log_file = app.config.get("LOG_FILE")
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_handlers = [
|
||||
RotatingFileHandler(
|
||||
filename=log_file,
|
||||
maxBytes=1024 * 1024 * 1024,
|
||||
backupCount=5,
|
||||
),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
]
|
||||
|
||||
logging.basicConfig(
|
||||
level=app.config.get("LOG_LEVEL"),
|
||||
format=app.config.get("LOG_FORMAT"),
|
||||
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||
handlers=log_handlers,
|
||||
force=True,
|
||||
)
|
||||
log_tz = app.config.get("LOG_TZ")
|
||||
if log_tz:
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
timezone = pytz.timezone(log_tz)
|
||||
|
||||
def time_converter(seconds):
|
||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
handler.formatter.converter = time_converter
|
||||
app.secret_key = dify_config.SECRET_KEY
|
||||
initialize_extensions(app)
|
||||
register_blueprints(app)
|
||||
register_commands(app)
|
||||
@ -117,6 +79,7 @@ def create_app() -> Flask:
|
||||
def initialize_extensions(app):
|
||||
# Since the application instance is now created, pass it to each Flask
|
||||
# extension instance to bind it to the Flask application instance (app)
|
||||
ext_logging.init_app(app)
|
||||
ext_compress.init_app(app)
|
||||
ext_code_based_extension.init()
|
||||
ext_database.init_app(app)
|
||||
@ -187,7 +150,7 @@ def register_blueprints(app):
|
||||
|
||||
CORS(
|
||||
web_bp,
|
||||
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
||||
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
@ -198,7 +161,7 @@ def register_blueprints(app):
|
||||
|
||||
CORS(
|
||||
console_app_bp,
|
||||
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
||||
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
|
||||
@ -277,6 +277,7 @@ def migrate_knowledge_vector_database():
|
||||
VectorType.TENCENT,
|
||||
VectorType.BAIDU,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
}
|
||||
page = 1
|
||||
while True:
|
||||
|
||||
@ -32,6 +32,21 @@ class SecurityConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
LOGIN_DISABLED: bool = Field(
|
||||
description="Whether to disable login checks",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ADMIN_API_KEY_ENABLE: bool = Field(
|
||||
description="Whether to enable admin api key for authentication",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ADMIN_API_KEY: Optional[str] = Field(
|
||||
description="admin api key for authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class AppExecutionConfig(BaseSettings):
|
||||
"""
|
||||
@ -304,6 +319,16 @@ class LoggingConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
LOG_FILE_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum file size for file rotation retention, the unit is megabytes (MB)",
|
||||
default=20,
|
||||
)
|
||||
|
||||
LOG_FILE_BACKUP_COUNT: PositiveInt = Field(
|
||||
description="Maximum file backup count file rotation retention",
|
||||
default=5,
|
||||
)
|
||||
|
||||
LOG_FORMAT: str = Field(
|
||||
description="Format string for log messages",
|
||||
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
|
||||
|
||||
@ -28,6 +28,7 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
|
||||
from configs.middleware.vdb.relyt_config import RelytConfig
|
||||
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from configs.middleware.vdb.upstash_config import UpstashConfig
|
||||
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
|
||||
from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
@ -246,5 +247,6 @@ class MiddlewareConfig(
|
||||
ElasticsearchConfig,
|
||||
InternalTestConfig,
|
||||
VikingDBConfig,
|
||||
UpstashConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
20
api/configs/middleware/vdb/upstash_config.py
Normal file
20
api/configs/middleware/vdb/upstash_config.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class UpstashConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Upstash vector database
|
||||
"""
|
||||
|
||||
UPSTASH_VECTOR_URL: Optional[str] = Field(
|
||||
description="URL of the upstash server (e.g., 'https://vector.upstash.io')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
UPSTASH_VECTOR_TOKEN: Optional[str] = Field(
|
||||
description="Token for authenticating with the upstash server",
|
||||
default=None,
|
||||
)
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.10.0",
|
||||
default="0.10.1",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -15,7 +15,9 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
|
||||
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
|
||||
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
DOCUMENT_EXTENSIONS.append("ppt")
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
else:
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
@ -15,7 +15,7 @@ from models.model import App, InstalledApp, RecommendedApp
|
||||
def admin_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not os.getenv("ADMIN_API_KEY"):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
@ -31,7 +31,7 @@ def admin_required(view):
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
if os.getenv("ADMIN_API_KEY") != auth_token:
|
||||
if dify_config.ADMIN_API_KEY != auth_token:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -52,4 +52,39 @@ class RuleGenerateApi(Resource):
|
||||
return rules
|
||||
|
||||
|
||||
class RuleCodeGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
max_tokens=CODE_GENERATION_MAX_TOKENS,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
return code_result
|
||||
|
||||
|
||||
api.add_resource(RuleGenerateApi, "/rule-generate")
|
||||
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import redirect, request
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
@ -196,10 +195,7 @@ class EmailCodeLoginApi(Resource):
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
)
|
||||
return NotAllowedCreateWorkspace()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
@ -94,17 +94,15 @@ class OAuthCallback(Resource):
|
||||
account = _generate_account(provider, user_info)
|
||||
except AccountNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except WorkSpaceNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
)
|
||||
|
||||
# Check account status
|
||||
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
|
||||
return {"error": "Account is banned or closed."}, 403
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
|
||||
@ -619,6 +619,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@ -657,6 +658,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
||||
@ -30,13 +30,12 @@ class FileApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(upload_config_fields)
|
||||
def get(self):
|
||||
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT
|
||||
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
|
||||
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
return {
|
||||
"file_size_limit": file_size_limit,
|
||||
"batch_count_limit": batch_count_limit,
|
||||
"image_file_size_limit": image_file_size_limit,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
|
||||
@ -41,7 +41,7 @@ class AlreadyActivateError(BaseHTTPException):
|
||||
|
||||
|
||||
class NotAllowedCreateWorkspace(BaseHTTPException):
|
||||
error_code = "unauthorized"
|
||||
error_code = "not_allowed_create_workspace"
|
||||
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
code = 400
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from flask import Response, request
|
||||
from flask_restful import Resource
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
@ -41,24 +41,39 @@ class FilePreviewApi(Resource):
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
sign = request.args.get("sign")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("timestamp", type=str, required=True, location="args")
|
||||
parser.add_argument("nonce", type=str, required=True, location="args")
|
||||
parser.add_argument("sign", type=str, required=True, location="args")
|
||||
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
|
||||
if not timestamp or not nonce or not sign:
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
|
||||
return {"content": "Invalid request."}, 400
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_signed_file_preview(
|
||||
generator, upload_file = FileService.get_file_generator_by_file_id(
|
||||
file_id=file_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
timestamp=args["timestamp"],
|
||||
nonce=args["nonce"],
|
||||
sign=args["sign"],
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return Response(generator, mimetype=mimetype)
|
||||
response = Response(
|
||||
generator,
|
||||
mimetype=upload_file.mime_type,
|
||||
direct_passthrough=True,
|
||||
headers={},
|
||||
)
|
||||
if upload_file.size > 0:
|
||||
response.headers["Content-Length"] = str(upload_file.size)
|
||||
if args["as_attachment"]:
|
||||
response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class WorkspaceWebappLogoApi(Resource):
|
||||
|
||||
@ -42,10 +42,10 @@ class ToolFilePreviewApi(Resource):
|
||||
stream,
|
||||
mimetype=tool_file.mimetype,
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Length": str(tool_file.size),
|
||||
},
|
||||
headers={},
|
||||
)
|
||||
if tool_file.size > 0:
|
||||
response.headers["Content-Length"] = str(tool_file.size)
|
||||
if args["as_attachment"]:
|
||||
response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ class MessageListApi(Resource):
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"message_files": fields.List(fields.String),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
|
||||
@ -46,7 +46,7 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||
"file_length": int(response.headers.get("Content-Length", -1)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
@ -53,11 +53,11 @@ class BasicVariablesConfigManager:
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description", ""),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options", []),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.file import FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -114,6 +114,16 @@ class VariableEntity(BaseModel):
|
||||
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, v: Any) -> str:
|
||||
return v or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
|
||||
@ -17,10 +17,13 @@ class FileUploadConfigManager:
|
||||
file_upload_dict = config.get("file_upload")
|
||||
if file_upload_dict:
|
||||
if file_upload_dict.get("enabled"):
|
||||
transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
|
||||
"allowed_upload_methods", []
|
||||
)
|
||||
data = {
|
||||
"image_config": {
|
||||
"number_limits": file_upload_dict["number_limits"],
|
||||
"transfer_methods": file_upload_dict["allowed_file_upload_methods"],
|
||||
"transfer_methods": transform_methods,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||
@ -240,7 +241,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
belongs_to="user",
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=("account" if account_id else "end_user"),
|
||||
created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
|
||||
created_by=account_id or end_user_id or "",
|
||||
)
|
||||
db.session.add(message_file)
|
||||
|
||||
@ -53,7 +53,7 @@ class BasedGenerateTaskPipeline:
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
|
||||
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception:
|
||||
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
|
||||
"""
|
||||
Handle error event.
|
||||
:param event: event
|
||||
@ -100,7 +100,7 @@ class BasedGenerateTaskPipeline:
|
||||
|
||||
return message
|
||||
|
||||
def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse:
|
||||
def _error_to_stream_response(self, e: Exception):
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
|
||||
@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
@ -232,30 +234,30 @@ class WorkflowCycleManage:
|
||||
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
# init workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
session.commit()
|
||||
session.refresh(workflow_node_execution)
|
||||
|
||||
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from flask import Config, Flask
|
||||
from flask import Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import QuotaUnit, RestrictModel
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
@ -44,32 +45,30 @@ class HostingConfiguration:
|
||||
moderation_config: HostedModerationConfig = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
config = app.config
|
||||
|
||||
if config.get("EDITION") != "CLOUD":
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
return
|
||||
|
||||
self.provider_map["azure_openai"] = self.init_azure_openai(config)
|
||||
self.provider_map["openai"] = self.init_openai(config)
|
||||
self.provider_map["anthropic"] = self.init_anthropic(config)
|
||||
self.provider_map["minimax"] = self.init_minimax(config)
|
||||
self.provider_map["spark"] = self.init_spark(config)
|
||||
self.provider_map["zhipuai"] = self.init_zhipuai(config)
|
||||
self.provider_map["azure_openai"] = self.init_azure_openai()
|
||||
self.provider_map["openai"] = self.init_openai()
|
||||
self.provider_map["anthropic"] = self.init_anthropic()
|
||||
self.provider_map["minimax"] = self.init_minimax()
|
||||
self.provider_map["spark"] = self.init_spark()
|
||||
self.provider_map["zhipuai"] = self.init_zhipuai()
|
||||
|
||||
self.moderation_config = self.init_moderation_config(config)
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
@staticmethod
|
||||
def init_azure_openai(app_config: Config) -> HostingProvider:
|
||||
def init_azure_openai() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TIMES
|
||||
if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
|
||||
if dify_config.HOSTED_AZURE_OPENAI_ENABLED:
|
||||
credentials = {
|
||||
"openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
|
||||
"openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
|
||||
"openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY,
|
||||
"openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE,
|
||||
"base_model_name": "gpt-35-turbo",
|
||||
}
|
||||
|
||||
quotas = []
|
||||
hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
|
||||
hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
|
||||
trial_quota = TrialHostingQuota(
|
||||
quota_limit=hosted_quota_limit,
|
||||
restrict_models=[
|
||||
@ -122,31 +121,31 @@ class HostingConfiguration:
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_openai(self, app_config: Config) -> HostingProvider:
|
||||
def init_openai(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas = []
|
||||
|
||||
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
|
||||
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
|
||||
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
|
||||
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
|
||||
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
|
||||
if dify_config.HOSTED_OPENAI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
|
||||
"openai_api_key": dify_config.HOSTED_OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
if app_config.get("HOSTED_OPENAI_API_BASE"):
|
||||
credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")
|
||||
if dify_config.HOSTED_OPENAI_API_BASE:
|
||||
credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE
|
||||
|
||||
if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
|
||||
credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")
|
||||
if dify_config.HOSTED_OPENAI_API_ORGANIZATION:
|
||||
credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
@ -156,26 +155,26 @@ class HostingConfiguration:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_anthropic(app_config: Config) -> HostingProvider:
|
||||
def init_anthropic() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
quotas = []
|
||||
|
||||
if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
|
||||
hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
|
||||
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
|
||||
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
|
||||
paid_quota = PaidHostingQuota()
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
|
||||
"anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY,
|
||||
}
|
||||
|
||||
if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
|
||||
credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")
|
||||
if dify_config.HOSTED_ANTHROPIC_API_BASE:
|
||||
credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
@ -185,9 +184,9 @@ class HostingConfiguration:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_minimax(app_config: Config) -> HostingProvider:
|
||||
def init_minimax() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
if app_config.get("HOSTED_MINIMAX_ENABLED"):
|
||||
if dify_config.HOSTED_MINIMAX_ENABLED:
|
||||
quotas = [FreeHostingQuota()]
|
||||
|
||||
return HostingProvider(
|
||||
@ -203,9 +202,9 @@ class HostingConfiguration:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_spark(app_config: Config) -> HostingProvider:
|
||||
def init_spark() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
if app_config.get("HOSTED_SPARK_ENABLED"):
|
||||
if dify_config.HOSTED_SPARK_ENABLED:
|
||||
quotas = [FreeHostingQuota()]
|
||||
|
||||
return HostingProvider(
|
||||
@ -221,9 +220,9 @@ class HostingConfiguration:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_zhipuai(app_config: Config) -> HostingProvider:
|
||||
def init_zhipuai() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
|
||||
if dify_config.HOSTED_ZHIPUAI_ENABLED:
|
||||
quotas = [FreeHostingQuota()]
|
||||
|
||||
return HostingProvider(
|
||||
@ -239,17 +238,15 @@ class HostingConfiguration:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_moderation_config(app_config: Config) -> HostedModerationConfig:
|
||||
if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"):
|
||||
return HostedModerationConfig(
|
||||
enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",")
|
||||
)
|
||||
def init_moderation_config() -> HostedModerationConfig:
|
||||
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
|
||||
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))
|
||||
|
||||
return HostedModerationConfig(enabled=False)
|
||||
|
||||
@staticmethod
|
||||
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
|
||||
models_str = app_config.get(env_var)
|
||||
def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]:
|
||||
models_str = dify_config.model_dump().get(env_var)
|
||||
models_list = models_str.split(",") if models_str else []
|
||||
return [
|
||||
RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)
|
||||
|
||||
@ -8,6 +8,8 @@ from core.llm_generator.output_parser.suggested_questions_after_answer import Su
|
||||
from core.llm_generator.prompts import (
|
||||
CONVERSATION_TITLE_PROMPT,
|
||||
GENERATOR_QA_PROMPT,
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
from core.model_manager import ModelManager
|
||||
@ -239,6 +241,54 @@ class LLMGenerator:
|
||||
|
||||
return rule_config
|
||||
|
||||
@classmethod
|
||||
def generate_code(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
instruction: str,
|
||||
model_config: dict,
|
||||
code_language: str = "javascript",
|
||||
max_tokens: int = 1000,
|
||||
) -> dict:
|
||||
if code_language == "python":
|
||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
else:
|
||||
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
inputs={
|
||||
"INSTRUCTION": instruction,
|
||||
"CODE_LANGUAGE": code_language,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider") if model_config else None,
|
||||
model=model_config.get("name") if model_config else None,
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
|
||||
|
||||
try:
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_code = response.message.content
|
||||
return {"code": generated_code, "language": code_language, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
|
||||
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
|
||||
|
||||
@ -61,6 +61,73 @@ User Input: yo, 你今天咋样?
|
||||
User Input:
|
||||
""" # noqa: E501
|
||||
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = (
|
||||
"You are an expert programmer. Generate code based on the following instructions:\n\n"
|
||||
"Instructions: {{INSTRUCTION}}\n\n"
|
||||
"Write the code in {{CODE_LANGUAGE}}.\n\n"
|
||||
"Please ensure that you meet the following requirements:\n"
|
||||
"1. Define a function named 'main'.\n"
|
||||
"2. The 'main' function must return a dictionary (dict).\n"
|
||||
"3. You may modify the arguments of the 'main' function, but include appropriate type hints.\n"
|
||||
"4. The returned dictionary should contain at least one key-value pair.\n\n"
|
||||
"5. You may ONLY use the following libraries in your code: \n"
|
||||
"- json\n"
|
||||
"- datetime\n"
|
||||
"- math\n"
|
||||
"- random\n"
|
||||
"- re\n"
|
||||
"- string\n"
|
||||
"- sys\n"
|
||||
"- time\n"
|
||||
"- traceback\n"
|
||||
"- uuid\n"
|
||||
"- os\n"
|
||||
"- base64\n"
|
||||
"- hashlib\n"
|
||||
"- hmac\n"
|
||||
"- binascii\n"
|
||||
"- collections\n"
|
||||
"- functools\n"
|
||||
"- operator\n"
|
||||
"- itertools\n\n"
|
||||
"Example:\n"
|
||||
"def main(arg1: str, arg2: int) -> dict:\n"
|
||||
" return {\n"
|
||||
' "result": arg1 * arg2,\n'
|
||||
" }\n\n"
|
||||
"IMPORTANT:\n"
|
||||
"- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
|
||||
"- DO NOT use markdown code blocks (``` or ``` python). Return the raw code directly.\n"
|
||||
"- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
|
||||
"- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
|
||||
"- Always use the format return {'result': ...} for the output.\n\n"
|
||||
"Generated Code:\n"
|
||||
)
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
|
||||
"You are an expert programmer. Generate code based on the following instructions:\n\n"
|
||||
"Instructions: {{INSTRUCTION}}\n\n"
|
||||
"Write the code in {{CODE_LANGUAGE}}.\n\n"
|
||||
"Please ensure that you meet the following requirements:\n"
|
||||
"1. Define a function named 'main'.\n"
|
||||
"2. The 'main' function must return an object.\n"
|
||||
"3. You may modify the arguments of the 'main' function, but include appropriate JSDoc annotations.\n"
|
||||
"4. The returned object should contain at least one key-value pair.\n\n"
|
||||
"5. The returned object should always be in the format: {result: ...}\n\n"
|
||||
"Example:\n"
|
||||
"function main(arg1, arg2) {\n"
|
||||
" return {\n"
|
||||
" result: arg1 * arg2\n"
|
||||
" };\n"
|
||||
"}\n\n"
|
||||
"IMPORTANT:\n"
|
||||
"- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n"
|
||||
"- DO NOT use markdown code blocks (``` or ``` javascript). Return the raw code directly.\n"
|
||||
"- The code should start immediately after this instruction, without any preceding newlines or spaces.\n"
|
||||
"- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n"
|
||||
"Generated Code:\n"
|
||||
)
|
||||
|
||||
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
"Please help me predict the three most likely questions that human would ask, "
|
||||
"and keeping each question under 20 characters.\n"
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Optional
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.file.models import FileType
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
@ -98,8 +99,9 @@ class TokenBufferMemory:
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
for file_obj in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(file_obj)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
if file_obj.type in {FileType.IMAGE, FileType.AUDIO}:
|
||||
prompt_message = file_manager.to_prompt_message_content(file_obj)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
|
||||
@ -218,7 +218,7 @@ For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` param
|
||||
However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -205,7 +205,7 @@ provider_credential_schema:
|
||||
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例A模型支持`top_k`,B模型不支持`top_k`,那么我们需要在这里动态生成模型参数的Schema,如下所示:
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
- claude-3-5-sonnet-20241022
|
||||
- claude-3-5-sonnet-20240620
|
||||
- claude-3-haiku-20240307
|
||||
- claude-3-opus-20240229
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
model: claude-3-5-sonnet-20241022
|
||||
label:
|
||||
en_US: claude-3-5-sonnet-20241022
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '3.00'
|
||||
output: '15.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -294,7 +294,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
|
||||
],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -148,7 +148,7 @@ class AzureRerankModel(RerankModel):
|
||||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
model: anthropic.claude-3-5-sonnet-20241022-v2:0
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet V2
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,60 @@
|
||||
model: eu.anthropic.claude-3-5-sonnet-20241022-v2:0
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet V2(EU.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,60 @@
|
||||
model: us.anthropic.claude-3-5-sonnet-20241022-v2:0
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet V2(US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -118,7 +118,7 @@ class HuggingfaceTeiRerankModel(RerankModel):
|
||||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -189,7 +189,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from httpx import Timeout
|
||||
from openai import (
|
||||
@ -212,7 +212,7 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}")
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
completion_model = None
|
||||
if credentials["completion_type"] == "chat_completion":
|
||||
completion_model = LLMMode.CHAT.value
|
||||
|
||||
@ -73,7 +73,7 @@ class LocalAISpeech2text(Speech2TextModel):
|
||||
InvokeBadRequestError: [InvokeBadRequestError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -115,7 +115,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model, zh_Hans=model),
|
||||
|
||||
@ -61,7 +61,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel):
|
||||
|
||||
return response.text
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -62,7 +62,7 @@ class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
@ -193,7 +194,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
),
|
||||
)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -408,7 +408,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -157,7 +157,7 @@ class SageMakerRerankModel(RerankModel):
|
||||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -111,7 +111,7 @@ class SageMakerSpeech2TextModel(Speech2TextModel):
|
||||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -180,7 +180,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -159,7 +159,7 @@ class SageMakerText2SpeechModel(TTSModel):
|
||||
|
||||
return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -40,7 +40,7 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
credentials["mode"] = "chat"
|
||||
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model, zh_Hans=model),
|
||||
|
||||
@ -50,7 +50,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model, zh_Hans=model),
|
||||
|
||||
@ -535,7 +535,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Architecture for defining customizable models
|
||||
|
||||
|
||||
@ -76,3 +76,4 @@ pricing:
|
||||
output: '0.12'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
||||
@ -10,7 +10,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
@ -26,7 +26,7 @@ parameter_rules:
|
||||
type: int
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 2000
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
|
||||
@ -10,7 +10,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@ -10,7 +10,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
@ -26,7 +26,7 @@ parameter_rules:
|
||||
type: int
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 2000
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from httpx import Response, post
|
||||
from yarl import URL
|
||||
@ -109,7 +110,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
|
||||
return text
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -0,0 +1,55 @@
|
||||
model: claude-3-5-sonnet-v2@20241022
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet v2
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
@ -298,7 +299,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
chunks = client.stream_chat(prompt_messages, **req_params)
|
||||
return _handle_stream_chat_response(chunks)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
@ -321,7 +321,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_dict
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -142,7 +142,7 @@ class XinferenceRerankModel(RerankModel):
|
||||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -129,7 +129,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
|
||||
return response["text"]
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -184,7 +184,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -116,7 +116,7 @@ class XinferenceText2SpeechModel(TTSModel):
|
||||
"""
|
||||
return self._tts_invoke_streaming(model, credentials, content_text, voice)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
@ -33,7 +33,7 @@ class PromptTemplateParser:
|
||||
key = match.group(1)
|
||||
value = inputs.get(key, match.group(0)) # return original matched string if key not found
|
||||
|
||||
if remove_template_variables:
|
||||
if remove_template_variables and isinstance(value, str):
|
||||
return PromptTemplateParser.remove_template_variables(value, self.with_variable_tmpl)
|
||||
return value
|
||||
|
||||
|
||||
@ -428,14 +428,13 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||
if not dataset.index_struct_dict:
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return QdrantVector(
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=QdrantConfig(
|
||||
endpoint=dify_config.QDRANT_URL,
|
||||
api_key=dify_config.QDRANT_API_KEY,
|
||||
root_path=config.root_path,
|
||||
root_path=current_app.config.root_path,
|
||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
|
||||
0
api/core/rag/datasource/vdb/upstash/__init__.py
Normal file
0
api/core/rag/datasource/vdb/upstash/__init__.py
Normal file
129
api/core/rag/datasource/vdb/upstash/upstash_vector.py
Normal file
129
api/core/rag/datasource/vdb/upstash/upstash_vector.py
Normal file
@ -0,0 +1,129 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from upstash_vector import Index, Vector
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class UpstashVectorConfig(BaseModel):
|
||||
url: str
|
||||
token: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["url"]:
|
||||
raise ValueError("Upstash URL is required")
|
||||
if not values["token"]:
|
||||
raise ValueError("Upstash Token is required")
|
||||
return values
|
||||
|
||||
|
||||
class UpstashVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: UpstashVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self._table_name = collection_name
|
||||
self.index = Index(url=config.url, token=config.token)
|
||||
|
||||
def _get_index_dimension(self) -> int:
|
||||
index_info = self.index.info()
|
||||
if index_info and index_info.dimension:
|
||||
return index_info.dimension
|
||||
else:
|
||||
return 1536
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
vectors = [
|
||||
Vector(
|
||||
id=str(uuid4()),
|
||||
vector=embedding,
|
||||
metadata=doc.metadata,
|
||||
data=doc.page_content,
|
||||
)
|
||||
for doc, embedding in zip(documents, embeddings)
|
||||
]
|
||||
self.index.upsert(vectors=vectors)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
response = self.get_ids_by_metadata_field("doc_id", id)
|
||||
return len(response) > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
item_ids = []
|
||||
for doc_id in ids:
|
||||
ids = self.get_ids_by_metadata_field("doc_id", doc_id)
|
||||
if id:
|
||||
item_ids += ids
|
||||
self._delete_by_ids(ids=item_ids)
|
||||
|
||||
def _delete_by_ids(self, ids: list[str]) -> None:
|
||||
if ids:
|
||||
self.index.delete(ids=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
query_result = self.index.query(
|
||||
vector=[1.001 * i for i in range(self._get_index_dimension())],
|
||||
include_metadata=True,
|
||||
top_k=1000,
|
||||
filter=f"{key} = '{value}'",
|
||||
)
|
||||
return [result.id for result in query_result]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._delete_by_ids(ids)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in result:
|
||||
metadata = record.metadata
|
||||
text = record.data
|
||||
score = record.score
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
self.index.reset()
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.UPSTASH
|
||||
|
||||
|
||||
class UpstashVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name))
|
||||
|
||||
return UpstashVector(
|
||||
collection_name=collection_name,
|
||||
config=UpstashVectorConfig(
|
||||
url=dify_config.UPSTASH_VECTOR_URL,
|
||||
token=dify_config.UPSTASH_VECTOR_TOKEN,
|
||||
),
|
||||
)
|
||||
@ -111,6 +111,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
|
||||
|
||||
return VikingDBVectorFactory
|
||||
case VectorType.UPSTASH:
|
||||
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
|
||||
|
||||
return UpstashVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
||||
@ -18,3 +18,4 @@ class VectorType(str, Enum):
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
BAIDU = "baidu"
|
||||
VIKINGDB = "vikingdb"
|
||||
UPSTASH = "upstash"
|
||||
|
||||
@ -21,6 +21,7 @@ from core.rag.extractor.unstructured.unstructured_eml_extractor import Unstructu
|
||||
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pdf_extractor import UnstructuredPDFExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
|
||||
@ -102,10 +103,10 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path)
|
||||
extractor = UnstructuredPDFExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension in {".md", ".markdown"}:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
if is_automatic
|
||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
)
|
||||
@ -116,17 +117,17 @@ class ExtractProcessor:
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == ".msg":
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".eml":
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".ppt":
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".pptx":
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".xml":
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".epub":
|
||||
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
else:
|
||||
# txt
|
||||
extractor = (
|
||||
|
||||
@ -10,24 +10,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredEmailExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load eml files.
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_email(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
elements = partition_email(filename=self._file_path)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
|
||||
@ -19,15 +19,23 @@ class UnstructuredEpubExtractor(BaseExtractor):
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.epub import partition_epub
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.epub import partition_epub
|
||||
|
||||
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
|
||||
|
||||
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -24,19 +24,21 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
|
||||
if the specified encoding fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.md import partition_md
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_md(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.md import partition_md
|
||||
|
||||
elements = partition_md(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -14,15 +14,21 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_msg(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
elements = partition_msg(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPDFExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
|
||||
api_url: Unstructured API URL
|
||||
|
||||
api_key: Unstructured API Key
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(
|
||||
filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto"
|
||||
)
|
||||
else:
|
||||
from unstructured.partition.pdf import partition_pdf
|
||||
|
||||
elements = partition_pdf(filename=self._file_path, strategy="auto")
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@ -7,7 +7,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load ppt files.
|
||||
|
||||
|
||||
Args:
|
||||
@ -21,9 +21,12 @@ class UnstructuredPPTExtractor(BaseExtractor):
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
raise NotImplementedError("Unstructured API Url is not configured")
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
|
||||
@ -7,22 +7,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTXExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load pptx files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_pptx(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
elements = partition_pptx(filename=self._file_path)
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
|
||||
@ -7,22 +7,29 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredXmlExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load xml files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -18,6 +18,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -109,9 +110,10 @@ class WordExtractor(BaseExtractor):
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
extension=image_ext,
|
||||
mime_type=mime_type,
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatedByRole.ACCOUNT,
|
||||
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
|
||||
@ -37,7 +37,7 @@ parameters:
|
||||
- value: mixtral-8x7b
|
||||
label:
|
||||
en_US: Mixtral
|
||||
default: gpt-3.5
|
||||
default: gpt-4o-mini
|
||||
label:
|
||||
en_US: Choose Model
|
||||
zh_Hans: 选择模型
|
||||
|
||||
BIN
api/core/tools/provider/builtin/feishu_base/_assets/icon.png
Normal file
BIN
api/core/tools/provider/builtin/feishu_base/_assets/icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.1 KiB |
@ -1,47 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" width="240px" height="240px" viewBox="0 0 240 240" enable-background="new 0 0 240 240" xml:space="preserve"> <image id="image0" width="240" height="240" x="0" y="0"
|
||||
xlink:href="
|
||||
AAB1MAAA6mAAADqYAAAXcJy6UTwAAAINUExURQAAAJ9g6qNg6qRd6KNe6aVg6J9g6p9g359a6qRd
|
||||
66Ne66Re66Rd66Rd66Re7Z9g559g76Ne7aRe6qJd76Fe66Ne66Re66Vd7KNc6aRe66Nd66Re7KRd
|
||||
66Nc6p9g36Re7KRf7KRe7KNe7KNc76Va6qVd66Ne66Re66Nd6qRb6KRe66Zd66Ve7KJd6aNe66Fe
|
||||
7KRd6p9Y56Re66Je6qRb6qNe66Re7KNc56Jf6qNe6qRe6qFe6KNe6qRf66Rf66Ne7KNe6qRe6qRd
|
||||
66Rd7KFe6aNc6qNe6qJd7Z9b6KNd659Z7KNd66Ne6aRe66Re7KNc66Nd66Ne66Jd6qNe6qNc6aJd
|
||||
6qNd6qRe66Jd7KFe6qJe66Re6qNe66Jd66Jd6aRd6qRd66Va759V6qVd6qFe56Ne6KJd6KRe66Zj
|
||||
67eB79Gu9eLM+OjX+uXS+dSz9bJ37s6p9LV87q9y7de49vn1/f///9q99qlo7NGu9Kxt7d/H+Pz5
|
||||
/vz6/qdj68CQ8e7h+/Pq/Maa8rqG8NzC9/Pr/Myk8+DH+OXR+cyk9Pn1/ujW+u3g+7iB7+vc+r2L
|
||||
8N3C9//+//bw/cmf87uG8OPM+ebR+axt7Muk8/Dm/Mif8/n0/c+p9MOV8fHm++7g+/n0/vbv/bqG
|
||||
78CQ8Kxu7cCR8ePM+LqH8MCP8bV87+LM+bqF8Muk9OLL+NSy9bqF75VHsr4AAABndFJOUwAYSHCA
|
||||
WDAQMJfn57+PVyAQb98/gO/mllDPr1/ebwiHn+6eLzCvf/63OL4/xkfOT58g9odfZ65Aj3fHT2+X
|
||||
j3f3z7efgEjPRzi3KKd/7sZAv+6fri+P3rY3X77G5j9oxn4fGGCAbzcMFjqxAAAAAWJLR0R1qGqY
|
||||
+wAAAAd0SU1FB+gGDQkfBmABjhYAAAXYSURBVHja7d35WxtFGAfwbUGDtRWKilAwFg+2VHtwqa31
|
||||
qlfV1qta7/vYkDQhJFnTliQWhFYoDaVY8ECw3vffaHgefZ5CZvbZIbMz83a/35/ngfkwx+5m3iyW
|
||||
hSAIgiAIgiAIgiAIgiAIQjAbNtbVa0vdNRuvjSjlNly3ydGd6zdvuaFRDbdpa7Nu7f9p3nJjU/De
|
||||
m27W7bwyLbcEbW5t021cm7atQc7thm26fYw01wdHbjdmAa9Kx61NwXijt+mm8dLWHgi5cbtuGDed
|
||||
t98RAPjOu3S7PBLEIBsNdjq7pIvNBjv2ju5wgR17593hAjvOPbtCBnZ27wkZWK6YAljqOqYAlrpX
|
||||
kwA79t6ecIGd3r6mcIGdjv6QgaUtYypgaZOaDFjWpKYDlrRT0wE7LQMhA9v3yvgEhBBYziqmBJay
|
||||
iimBpQwxKbDdFw0X2Nl8X8jAEuY0LbB9f80H58TAtT9C0AJLmNPEwL37QgaufRFTA29rCBfY2V/r
|
||||
OQQ18O4HQgau+aGYGrjzQMjAdhfAVzm41idEcuBa7zwANjwAAwwwwKQCcIDg2GA8cXx1kqnBGK/5
|
||||
UHpta16SqWETwZlszq1O7tMhZuv8CVZrXnKpk8aBh05xOptgDY+Y13VHsj7FysB8QSFerG5e+kzI
|
||||
67qnRw0DDye4fT0+Vv3nyQp6XffzjFng0XH+Apyoaj10Rhh8dtAs8BeT3K4W0lWtx6aEwedKlMHT
|
||||
Z4XB58tmgUvnuF0dr95vMjPC4IS/a7EysAeB1dV0QdB7IZs3C8zftWYvMloLz2nGVq8ZXJyYZfZ0
|
||||
Ms4cmjmxbcvvZVjpvfTcl9UbV+FSmXMzPTwz6ZtbuDTntxNqn5bmp+dWZ3reo3VmzmemMzHfXcDj
|
||||
IcAAA0wqAAMMMMCkAjDAAANMKgADDDDApAIwwBLB+XI6vjrp0kmB1rx4/RSN4PwC67BlfCHPbB0r
|
||||
nRc4ePjq67zPXqgD82oeCuzqjIuzAl7X/eaET7E68ASvxiP3LaPGQ/BoiXyNRzEl6DXvuNSjxoNx
|
||||
IB66Go91lDyErsbDNPDidyGb0h5jxqjxKMbp13hw5/RkmnGcLXxZWvI3o1XeeMSXmT0dSbGGpjjx
|
||||
vZB3mV0pohPs5EeXLlTPxNOcW8viIqM1N7wbVK3gCvnywtrb/h/4RR4CDw+L/gpLlYNNCMAAAwww
|
||||
qQAMMMAAkwrAAAMMMKkADDDAAJMKwIGCfxwcW5OMUGte/H9KqxIcKyerD4xyyZ+K7ObTP/s/Xiok
|
||||
fzEPnE+zqzZy7OqMskhNC+dLyHrB3BKAcVZ1hkeFBDvGfWHaQ8D6SvyE2ItaXHckXvTVD0MPxNfx
|
||||
0gPTDsQ96q6uztdaeNR4uPGq1ut5ccmvZoE9ypYYO+xvvwuDp/4wC+xRw8CoVlhHYZppLx/il6kU
|
||||
WDUPpSVB8PKf/vqh7jrM3XjZLxDjVITwwq4U0Qp2MlmWYWSG84o4IfFyat5nL5Q+PAynT02tTiJ1
|
||||
Oea/NS9//f2P7xd54PEQYIABJhWAAQYYYFIBWDSN23UTFIM3PKiboBgceUg3QTHYevgR3QbF4IFN
|
||||
ug2Kwa2P6jYoBncf1G1QDO55TLdBMTj6OKldS8I/5G1v1o1QDKa1iCWAad16SADTmtMywKTmtAww
|
||||
qX1aBth6Yr9uhmIwpW1LCtja86Ruh2IwoSGWA7aeIrOKJYGjXb26JWrB1qGndUsUg63+Dt0UxeBI
|
||||
H41JLQ1sPfOsbotiMJGdWiKYxjKWCY4eaNHNUQu2IoeP6PaoBVMQywUTEEsGW5F9hq9j2eDK1ek5
|
||||
WzdKLdg6tMNkcQBgsxey3ReVDras51+wdcO44K4AvCuD/KKh5I7+QMCW1fCSkWT76MsBgQ0lHzks
|
||||
f8+6gvzKMcPI9quvBeitJPL6GyYNs/3mrmC9K+kZOPqWbQba3vl28N6VRBve2ftuRW1rhFd+97H3
|
||||
3lfj/S8fdLd+WFevKx99/EkQdxwIgiAIgiAIgiAIgiAIgiBB519+T+5Fl+ldNwAAACV0RVh0ZGF0
|
||||
ZTpjcmVhdGUAMjAyNC0wNi0xM1QwOTozMTowNiswMDowMPHqs70AAAAldEVYdGRhdGU6bW9kaWZ5
|
||||
ADIwMjQtMDYtMTNUMDk6MzE6MDYrMDA6MDCAtwsBAAAAKHRFWHRkYXRlOnRpbWVzdGFtcAAyMDI0
|
||||
LTA2LTEzVDA5OjMxOjA2KzAwOjAw16Iq3gAAAABJRU5ErkJggg==" />
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.7 KiB |
@ -1,8 +1,7 @@
|
||||
from core.tools.provider.builtin.feishu_base.tools.get_tenant_access_token import GetTenantAccessTokenTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.utils.feishu_api_utils import auth
|
||||
|
||||
|
||||
class FeishuBaseProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
GetTenantAccessTokenTool()
|
||||
pass
|
||||
auth(credentials)
|
||||
|
||||
@ -5,10 +5,32 @@ identity:
|
||||
en_US: Feishu Base
|
||||
zh_Hans: 飞书多维表格
|
||||
description:
|
||||
en_US: Feishu Base
|
||||
zh_Hans: 飞书多维表格
|
||||
icon: icon.svg
|
||||
en_US: |
|
||||
Feishu base, requires the following permissions: bitable:app.
|
||||
zh_Hans: |
|
||||
飞书多维表格,需要开通以下权限: bitable:app。
|
||||
icon: icon.png
|
||||
tags:
|
||||
- social
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
app_id:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: APP ID
|
||||
placeholder:
|
||||
en_US: Please input your feishu app id
|
||||
zh_Hans: 请输入你的飞书 app id
|
||||
help:
|
||||
en_US: Get your app_id and app_secret from Feishu
|
||||
zh_Hans: 从飞书获取您的 app_id 和 app_secret
|
||||
url: https://open.larkoffice.com/app
|
||||
app_secret:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: APP Secret
|
||||
placeholder:
|
||||
en_US: Please input your app secret
|
||||
zh_Hans: 请输入你的飞书 app secret
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AddBaseRecordTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records"
|
||||
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
if not table_id:
|
||||
return self.create_text_message("Invalid parameter table_id")
|
||||
|
||||
fields = tool_parameters.get("fields", "")
|
||||
if not fields:
|
||||
return self.create_text_message("Invalid parameter fields")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {"fields": json.loads(fields)}
|
||||
|
||||
try:
|
||||
res = httpx.post(
|
||||
url.format(app_token=app_token, table_id=table_id),
|
||||
headers=headers,
|
||||
params=params,
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to add base record, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to add base record. {}".format(e))
|
||||
@ -1,66 +0,0 @@
|
||||
identity:
|
||||
name: add_base_record
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Add Base Record
|
||||
zh_Hans: 在多维表格数据表中新增一条记录
|
||||
description:
|
||||
human:
|
||||
en_US: Add Base Record
|
||||
zh_Hans: |
|
||||
在多维表格数据表中新增一条记录,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-record/create
|
||||
llm: Add a new record in the multidimensional table data table.
|
||||
parameters:
|
||||
- name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: token
|
||||
zh_Hans: 凭证
|
||||
human_description:
|
||||
en_US: API access token parameter, tenant_access_token or user_access_token
|
||||
zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token
|
||||
llm_description: API access token parameter, tenant_access_token or user_access_token
|
||||
form: llm
|
||||
|
||||
- name: app_token
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: app_token
|
||||
zh_Hans: 多维表格
|
||||
human_description:
|
||||
en_US: bitable app token
|
||||
zh_Hans: 多维表格的唯一标识符 app_token
|
||||
llm_description: bitable app token
|
||||
form: llm
|
||||
|
||||
- name: table_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: table_id
|
||||
zh_Hans: 多维表格的数据表
|
||||
human_description:
|
||||
en_US: bitable table id
|
||||
zh_Hans: 多维表格数据表的唯一标识符 table_id
|
||||
llm_description: bitable table id
|
||||
form: llm
|
||||
|
||||
- name: fields
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: fields
|
||||
zh_Hans: 数据表的列字段内容
|
||||
human_description:
|
||||
en_US: The fields of the Base data table are the columns of the data table.
|
||||
zh_Hans: |
|
||||
要增加一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"}
|
||||
当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。
|
||||
不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure
|
||||
llm_description: |
|
||||
要增加一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"}
|
||||
当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。
|
||||
不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure
|
||||
form: llm
|
||||
@ -0,0 +1,21 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class AddRecordsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
app_token = tool_parameters.get("app_token")
|
||||
table_id = tool_parameters.get("table_id")
|
||||
table_name = tool_parameters.get("table_name")
|
||||
records = tool_parameters.get("records")
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
|
||||
res = client.add_records(app_token, table_id, table_name, records, user_id_type)
|
||||
return self.create_json_message(res)
|
||||
@ -0,0 +1,91 @@
|
||||
identity:
|
||||
name: add_records
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Add Records
|
||||
zh_Hans: 新增多条记录
|
||||
description:
|
||||
human:
|
||||
en_US: Add Multiple Records to Multidimensional Table
|
||||
zh_Hans: 在多维表格数据表中新增多条记录
|
||||
llm: A tool for adding multiple records to a multidimensional table. (在多维表格数据表中新增多条记录)
|
||||
parameters:
|
||||
- name: app_token
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: app_token
|
||||
zh_Hans: app_token
|
||||
human_description:
|
||||
en_US: Unique identifier for the multidimensional table, supports inputting document URL.
|
||||
zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。
|
||||
llm_description: 多维表格的唯一标识符,支持输入文档 URL。
|
||||
form: llm
|
||||
|
||||
- name: table_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: table_id
|
||||
zh_Hans: table_id
|
||||
human_description:
|
||||
en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously.
|
||||
zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。
|
||||
llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。
|
||||
form: llm
|
||||
|
||||
- name: table_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: table_name
|
||||
zh_Hans: table_name
|
||||
human_description:
|
||||
en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously.
|
||||
zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。
|
||||
llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。
|
||||
form: llm
|
||||
|
||||
- name: records
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: records
|
||||
zh_Hans: 记录列表
|
||||
human_description:
|
||||
en_US: |
|
||||
List of records to be added in this request. Example value: [{"multi-line-text":"text content","single_select":"option 1","date":1674206443000}]
|
||||
For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure).
|
||||
zh_Hans: |
|
||||
本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。
|
||||
当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。
|
||||
llm_description: |
|
||||
本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。
|
||||
当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。
|
||||
form: llm
|
||||
|
||||
- name: user_id_type
|
||||
type: select
|
||||
required: false
|
||||
options:
|
||||
- value: open_id
|
||||
label:
|
||||
en_US: open_id
|
||||
zh_Hans: open_id
|
||||
- value: union_id
|
||||
label:
|
||||
en_US: union_id
|
||||
zh_Hans: union_id
|
||||
- value: user_id
|
||||
label:
|
||||
en_US: user_id
|
||||
zh_Hans: user_id
|
||||
default: "open_id"
|
||||
label:
|
||||
en_US: user_id_type
|
||||
zh_Hans: 用户 ID 类型
|
||||
human_description:
|
||||
en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id.
|
||||
zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。
|
||||
form: form
|
||||
@ -1,41 +1,18 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import httpx
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class CreateBaseTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps"
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
name = tool_parameters.get("name")
|
||||
folder_token = tool_parameters.get("folder_token")
|
||||
|
||||
name = tool_parameters.get("name", "")
|
||||
folder_token = tool_parameters.get("folder_token", "")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {"name": name, "folder_token": folder_token}
|
||||
|
||||
try:
|
||||
res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to create base, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to create base. {}".format(e))
|
||||
res = client.create_base(name, folder_token)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@ -6,32 +6,21 @@ identity:
|
||||
zh_Hans: 创建多维表格
|
||||
description:
|
||||
human:
|
||||
en_US: Create base
|
||||
en_US: Create Multidimensional Table in Specified Directory
|
||||
zh_Hans: 在指定目录下创建多维表格
|
||||
llm: A tool for create a multidimensional table in the specified directory.
|
||||
llm: A tool for creating a multidimensional table in a specified directory. (在指定目录下创建多维表格)
|
||||
parameters:
|
||||
- name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: token
|
||||
zh_Hans: 凭证
|
||||
human_description:
|
||||
en_US: API access token parameter, tenant_access_token or user_access_token
|
||||
zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token
|
||||
llm_description: API access token parameter, tenant_access_token or user_access_token
|
||||
form: llm
|
||||
|
||||
- name: name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: name
|
||||
zh_Hans: name
|
||||
zh_Hans: 多维表格 App 名字
|
||||
human_description:
|
||||
en_US: Base App Name
|
||||
zh_Hans: 多维表格App名字
|
||||
llm_description: Base App Name
|
||||
en_US: |
|
||||
Name of the multidimensional table App. Example value: "A new multidimensional table".
|
||||
zh_Hans: 多维表格 App 名字,示例值:"一篇新的多维表格"。
|
||||
llm_description: 多维表格 App 名字,示例值:"一篇新的多维表格"。
|
||||
form: llm
|
||||
|
||||
- name: folder_token
|
||||
@ -39,9 +28,15 @@ parameters:
|
||||
required: false
|
||||
label:
|
||||
en_US: folder_token
|
||||
zh_Hans: 多维表格App归属文件夹
|
||||
zh_Hans: 多维表格 App 归属文件夹
|
||||
human_description:
|
||||
en_US: Base App home folder. The default is empty, indicating that Base will be created in the cloud space root directory.
|
||||
zh_Hans: 多维表格App归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。
|
||||
llm_description: Base App home folder. The default is empty, indicating that Base will be created in the cloud space root directory.
|
||||
en_US: |
|
||||
Folder where the multidimensional table App belongs. Default is empty, meaning the table will be created in the root directory of the cloud space. Example values: Fa3sfoAgDlMZCcdcJy1cDFg8nJc or https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc.
|
||||
The folder_token must be an existing folder and supports inputting folder token or folder URL.
|
||||
zh_Hans: |
|
||||
多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Fa3sfoAgDlMZCcdcJy1cDFg8nJc 或者 https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc。
|
||||
folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。
|
||||
llm_description: |
|
||||
多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Fa3sfoAgDlMZCcdcJy1cDFg8nJc 或者 https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc。
|
||||
folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。
|
||||
form: llm
|
||||
|
||||
@ -1,48 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CreateBaseTableTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables"
|
||||
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
name = tool_parameters.get("name", "")
|
||||
|
||||
fields = tool_parameters.get("fields", "")
|
||||
if not fields:
|
||||
return self.create_text_message("Invalid parameter fields")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {"table": {"name": name, "fields": json.loads(fields)}}
|
||||
|
||||
try:
|
||||
res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to create base table, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to create base table. {}".format(e))
|
||||
@ -1,106 +0,0 @@
|
||||
identity:
|
||||
name: create_base_table
|
||||
author: Doug Lea
|
||||
label:
|
||||
en_US: Create Base Table
|
||||
zh_Hans: 多维表格新增一个数据表
|
||||
description:
|
||||
human:
|
||||
en_US: Create base table
|
||||
zh_Hans: |
|
||||
多维表格新增一个数据表,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table/create
|
||||
llm: A tool for add a new data table to the multidimensional table.
|
||||
parameters:
|
||||
- name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: token
|
||||
zh_Hans: 凭证
|
||||
human_description:
|
||||
en_US: API access token parameter, tenant_access_token or user_access_token
|
||||
zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token
|
||||
llm_description: API access token parameter, tenant_access_token or user_access_token
|
||||
form: llm
|
||||
|
||||
- name: app_token
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: app_token
|
||||
zh_Hans: 多维表格
|
||||
human_description:
|
||||
en_US: bitable app token
|
||||
zh_Hans: 多维表格的唯一标识符 app_token
|
||||
llm_description: bitable app token
|
||||
form: llm
|
||||
|
||||
- name: name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: name
|
||||
zh_Hans: name
|
||||
human_description:
|
||||
en_US: Multidimensional table data table name
|
||||
zh_Hans: 多维表格数据表名称
|
||||
llm_description: Multidimensional table data table name
|
||||
form: llm
|
||||
|
||||
- name: fields
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: fields
|
||||
zh_Hans: fields
|
||||
human_description:
|
||||
en_US: Initial fields of the data table
|
||||
zh_Hans: |
|
||||
数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。
|
||||
field_name:字段名;
|
||||
type: 字段类型;可选值有
|
||||
1:多行文本
|
||||
2:数字
|
||||
3:单选
|
||||
4:多选
|
||||
5:日期
|
||||
7:复选框
|
||||
11:人员
|
||||
13:电话号码
|
||||
15:超链接
|
||||
17:附件
|
||||
18:单向关联
|
||||
20:公式
|
||||
21:双向关联
|
||||
22:地理位置
|
||||
23:群组
|
||||
1001:创建时间
|
||||
1002:最后更新时间
|
||||
1003:创建人
|
||||
1004:修改人
|
||||
1005:自动编号
|
||||
llm_description: |
|
||||
数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。
|
||||
field_name:字段名;
|
||||
type: 字段类型;可选值有
|
||||
1:多行文本
|
||||
2:数字
|
||||
3:单选
|
||||
4:多选
|
||||
5:日期
|
||||
7:复选框
|
||||
11:人员
|
||||
13:电话号码
|
||||
15:超链接
|
||||
17:附件
|
||||
18:单向关联
|
||||
20:公式
|
||||
21:双向关联
|
||||
22:地理位置
|
||||
23:群组
|
||||
1001:创建时间
|
||||
1002:最后更新时间
|
||||
1003:创建人
|
||||
1004:修改人
|
||||
1005:自动编号
|
||||
form: llm
|
||||
@ -0,0 +1,20 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class CreateTableTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
app_token = tool_parameters.get("app_token")
|
||||
table_name = tool_parameters.get("table_name")
|
||||
default_view_name = tool_parameters.get("default_view_name")
|
||||
fields = tool_parameters.get("fields")
|
||||
|
||||
res = client.create_table(app_token, table_name, default_view_name, fields)
|
||||
return self.create_json_message(res)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user