Compare commits

..

58 Commits

Author SHA1 Message Date
Yi
9125971da2 fix: margin in rerank switch 2024-10-09 17:59:42 +08:00
Yi
6f9d6cd3e1 fix: edit external knowledge api warning message 2024-09-30 14:23:51 +08:00
Yi
f6074b6545 fix: chatbot rerank popup logics 2024-09-30 14:02:23 +08:00
Yi
fd4d7e9002 fix: edit dataset card from datasets page, naming 2024-09-30 11:58:46 +08:00
Yi
383a60a7df fix: rerank open logics added to chatgpt, modified the hit detail modal styling 2024-09-29 18:33:27 +08:00
Yi
918df23f64 Merge branch 'feat/external-knowledge-api' of github.com:langgenius/dify into feat/external-knowledge-api 2024-09-29 17:54:33 +08:00
Yi
bc81d2d30d fix: styling issues and create knowledge api from the knowledge base creation page 2024-09-29 17:26:49 +08:00
89290183c6 add score threshold enabled 2024-09-29 15:36:59 +08:00
Yi
6508e7e1e4 fix: retrieval config for rerank cases 2024-09-29 14:52:47 +08:00
1955de2463 add tidb on qdrant whitelist and batch job 2024-09-29 14:33:28 +08:00
4ee3743b20 add tidb on qdrant whitelist and batch job 2024-09-29 11:57:15 +08:00
Yi
e5d8c07508 add helper text 2024-09-29 11:12:03 +08:00
Yi
69c0f3f2ad fix: default selection issue & trigger retrieval setting unintentionally 2024-09-28 14:13:02 +08:00
Yi
b92fced974 Merge branch 'main' into feat/external-knowledge-api 2024-09-27 22:39:04 +08:00
Yi
644ab2df35 feat: add new external knowledge api from the knowledge create page 2024-09-27 22:38:13 +08:00
020766a5e8 Merge branch 'main' into feat/external-knowledge-api
# Conflicts:
#	api/poetry.lock
2024-09-27 17:49:40 +08:00
Yi
c9e3a9e56a feat: add external api from the create external knowledge page 2024-09-27 17:44:01 +08:00
9c9352bc73 update to external knowledge api 2024-09-27 16:17:45 +08:00
2a1cba9f4d Merge remote-tracking branch 'origin/feat/external-knowledge-api' into feat/external-knowledge-api 2024-09-27 16:03:18 +08:00
8e73844781 update to external knowledge api 2024-09-27 16:02:59 +08:00
Yi
5554cf7b20 feat: connect knowledge base to app 2024-09-27 15:50:22 +08:00
Yi
1597f34471 Merge branch 'feat/external-knowledge-api' of github.com:langgenius/dify into feat/external-knowledge-api 2024-09-27 10:11:19 +08:00
Yi
1c7cb3fbc0 feat: external knowledge base 2024-09-27 00:33:56 +08:00
611f0fb3f6 update to external knowledge api 2024-09-26 16:38:53 +08:00
Yi
ff0260e564 fix: minor issues 2024-09-26 10:23:06 +08:00
Yi
85deb9d7af Merge branch 'feat/external-knowledge-api' of github.com:langgenius/dify into feat/external-knowledge-api 2024-09-26 01:01:30 +08:00
Yi
cfa4825073 feat: external knowledge api crud frontend & connect external knowledge base 2024-09-26 01:00:49 +08:00
5fa86074ed update to external knowledge api 2024-09-25 13:31:15 +08:00
Yi
d6c604a356 Merge branch 'feat/external-knowledge-api' of github.com:langgenius/dify into feat/external-knowledge-api 2024-09-25 13:05:57 +08:00
c927c97310 update to external knowledge api 2024-09-25 12:37:23 +08:00
a69dcb8bee add external_retrieval_model 2024-09-25 10:57:12 +08:00
02b06c420e add external_retrieval_model 2024-09-24 23:52:01 +08:00
a258f8dfdf remove description 2024-09-24 23:32:23 +08:00
a53b4fb2ff remove description 2024-09-24 22:28:23 +08:00
680c1bd41d remove description 2024-09-24 21:37:55 +08:00
Yi
b9b8ec1758 Merge branch 'feat/external-knowledge-api' of github.com:langgenius/dify into feat/external-knowledge-api 2024-09-24 20:09:07 +08:00
6452c34818 external knowledge api 2024-09-24 19:54:17 +08:00
Yi
2655dd2026 Merge branch 'feat/external-knowledge-api' of github.com:langgenius/dify into feat/external-knowledge-api 2024-09-24 19:33:15 +08:00
30dc137ccc Merge branch 'main' into feat/external-knowledge-api
# Conflicts:
#	api/core/rag/retrieval/dataset_retrieval.py
2024-09-24 18:03:14 +08:00
573b61b7e8 External knowledge api 2024-09-24 18:02:03 +08:00
089da063d4 External knowledge api 2024-09-24 18:00:45 +08:00
ed92c90a40 External knowledge api 2024-09-24 17:52:16 +08:00
Yi
fbedd08292 feat: add external api 2024-09-23 23:34:01 +08:00
19c526120c external knowledge api 2024-09-19 17:07:33 +08:00
37f7d5732a external knowledge api 2024-09-18 15:29:30 +08:00
dcb033d221 Merge branch 'main' into feat/external-knowledge
# Conflicts:
#	api/core/rag/datasource/retrieval_service.py
#	api/models/dataset.py
#	api/services/dataset_service.py
2024-09-18 14:40:43 +08:00
9f894bb3b3 external knowledge api 2024-09-18 14:36:51 +08:00
89e81873c4 merge error 2024-09-13 09:49:24 +08:00
9ca0e56a8a external dataset binding 2024-09-11 16:59:19 +08:00
e7c77d961b Merge branch 'main' into feat/external-knowledge
# Conflicts:
#	api/controllers/console/auth/data_source_oauth.py
2024-09-09 15:54:43 +08:00
a63e15081f update nltk version 2024-08-23 16:43:47 +08:00
0724640bbb fix rerank mode is none 2024-08-22 15:36:47 +08:00
cb70e12827 fix rerank mode is none 2024-08-22 15:33:43 +08:00
067b956b2c merge migration 2024-08-21 16:25:18 +08:00
e7762b731c external knowledge 2024-08-20 16:18:35 +08:00
f6c8390b0b external knowledge 2024-08-20 12:47:51 +08:00
4fd57929df Merge branch 'main' into feat/external-knowledge 2024-08-20 12:46:37 +08:00
517cdb2ca4 add external knowledge 2024-08-20 11:13:29 +08:00
174 changed files with 4417 additions and 2910 deletions

View File

@ -125,7 +125,7 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@ -291,4 +291,4 @@ POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=
POSITION_PROVIDER_EXCLUDES=

View File

@ -1,15 +1,6 @@
from typing import Annotated, Optional
from pydantic import (
AliasChoices,
Field,
HttpUrl,
NegativeInt,
NonNegativeInt,
PositiveFloat,
PositiveInt,
computed_field,
)
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
@ -471,11 +462,6 @@ class MailConfig(BaseSettings):
default=False,
)
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
)
class RagEtlConfig(BaseSettings):
"""
@ -612,33 +598,6 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
default=False,
)
ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field(
description="whether to enable email password login",
default=True,
)
ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field(
description="whether to enable github/google oauth login",
default=False,
)
EMAIL_CODE_LOGIN_TOKEN_EXPIRY_HOURS: PositiveFloat = Field(
description="expiry time in hours for email code login token",
default=1 / 12,
)
ALLOW_REGISTER: bool = Field(
description="whether to enable register",
default=True,
)
ALLOW_CREATE_WORKSPACE: bool = Field(
description="whether to enable create workspace",
default=False,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -664,7 +623,6 @@ class FeatureConfig(
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
LoginConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,

View File

@ -5,6 +5,7 @@ from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
from configs.middleware.external.bedrock_config import BedrockConfig
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
@ -222,5 +223,6 @@ class MiddlewareConfig(
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
BedrockConfig,
):
pass

View File

@ -0,0 +1,20 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class BedrockConfig(BaseSettings):
"""
bedrock configs
"""
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="AWS secret access key",
default=None,
)
AWS_ACCESS_KEY_ID: Optional[str] = Field(
description="AWS secret access id",
default=None,
)

View File

@ -37,7 +37,17 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,
external,
file,
hit_testing,
test_external,
website,
)
# Import explore controllers
from .explore import (

View File

@ -1,15 +1,17 @@
import base64
import datetime
import secrets
from flask import request
from flask_restful import Resource, reqparse
from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import StrLen, email, get_remote_ip, timezone
from models.account import AccountStatus, Tenant
from services.account_service import AccountService, RegisterService
from libs.helper import StrLen, email, timezone
from libs.password import hash_password, valid_password
from models.account import AccountStatus
from services.account_service import RegisterService
class ActivateCheckApi(Resource):
@ -25,18 +27,8 @@ class ActivateCheckApi(Resource):
token = args["token"]
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
data = invitation.get("data", {})
tenant: Tenant = invitation.get("tenant", None)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
return {
"is_valid": invitation is not None,
"data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email},
}
else:
return {"is_valid": False}
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
class ActivateApi(Resource):
@ -46,6 +38,7 @@ class ActivateApi(Resource):
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
@ -61,6 +54,15 @@ class ActivateApi(Resource):
account = invitation["account"]
account.name = args["name"]
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(args["password"], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
@ -68,9 +70,7 @@ class ActivateApi(Resource):
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
token = AccountService.login(account, ip_address=get_remote_ip(request))
return {"result": "success", "data": token}
return {"result": "success"}
api.add_resource(ActivateCheckApi, "/activate/check")

View File

@ -27,29 +27,5 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded"
description = "Too many password reset emails have been sent. Please try again in 5 minutes."
code = 429
class EmailCodeError(BaseHTTPException):
error_code = "email_code_error"
description = "Email code is invalid or expired."
code = 400
class EmailOrPasswordMismatchError(BaseHTTPException):
error_code = "email_or_password_mismatch"
description = "The email or password is mismatched."
code = 400
class EmailPasswordLoginLimitError(BaseHTTPException):
error_code = "email_code_login_limit"
description = "The account was locked for 24 hours because the password was entered too many times."
code = 429
class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
error_code = "email_code_login_rate_limit_exceeded"
description = "Too many login emails have been sent. Please try again in 5 minutes."
description = "Password reset rate limit exceeded. Try again later."
code = 429

View File

@ -1,80 +1,65 @@
import base64
import logging
import secrets
from flask import request
from flask_restful import Resource, reqparse
from configs import dify_config
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
PasswordResetRateLimitExceededError,
)
from controllers.console.error import EmailSendIpLimitError, NotAllowedCreateWorkspace, NotAllowedRegister
from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, get_remote_ip
from libs.helper import email as email_validate
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
from services.account_service import AccountService
from services.errors.account import RateLimitExceededError
class ForgotPasswordSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
parser.add_argument("email", type=str, required=True, location="json")
args = parser.parse_args()
ip_address = get_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
email = args["email"]
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
if not email_validate(email):
raise InvalidEmailError()
account = Account.query.filter_by(email=email).first()
if account:
try:
AccountService.send_reset_password_email(account=account)
except RateLimitExceededError:
logging.warning(f"Rate limit exceeded for email: {account.email}")
raise PasswordResetRateLimitExceededError()
else:
language = "en-US"
# Return success to avoid revealing email registration status
logging.warning(f"Attempt to reset password for unregistered email: {email}")
account = Account.query.filter_by(email=args["email"]).first()
token = None
if account is None:
if dify_config.ALLOW_REGISTER:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
return {"result": "success", "data": token}
return {"result": "success"}
class ForgotPasswordCheckApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
token = args["token"]
user_email = args["email"]
reset_data = AccountService.get_reset_password_data(token)
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
raise EmailCodeError()
return {"is_valid": True, "email": token_data.get("email")}
if reset_data is None:
return {"is_valid": False, "email": None}
return {"is_valid": True, "email": reset_data.get("email")}
class ForgotPasswordResetApi(Resource):
@ -107,30 +92,11 @@ class ForgotPasswordResetApi(Resource):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first()
if account:
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not dify_config.ALLOW_CREATE_WORKSPACE:
raise NotAllowedCreateWorkspace()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
else:
account = AccountService.create_account_and_tenant(
email=reset_data.get("email"),
name=reset_data.get("email"),
password=password_confirm,
interface_language=languages[0],
)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
token = AccountService.login(account, ip_address=get_remote_ip(request))
return {"result": "success", "data": token}
return {"result": "success"}
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")

View File

@ -1,28 +1,16 @@
from typing import cast
import flask_login
from flask import redirect, request
from flask import request
from flask_restful import Resource, reqparse
import services
from configs import dify_config
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
EmailCodeError,
EmailOrPasswordMismatchError,
EmailPasswordLoginLimitError,
InvalidEmailError,
InvalidTokenError,
)
from controllers.console.error import EmailSendIpLimitError, NotAllowedCreateWorkspace, NotAllowedRegister
from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, get_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.account_service import AccountService, TenantService
class LoginApi(Resource):
@ -35,37 +23,15 @@ class LoginApi(Resource):
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
args = parser.parse_args()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invitation = args["invite_token"]
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
# todo: Verify the recaptcha
try:
if invitation:
data = invitation.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args["email"]:
raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
else:
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise NotAllowedRegister()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
if not dify_config.ALLOW_REGISTER:
raise NotAllowedRegister()
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError as e:
return {"code": "unauthorized", "message": str(e)}, 401
token = AccountService.send_reset_password_email(email=args["email"])
return {"result": "fail", "data": token, "message": "account_not_found"}
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@ -75,7 +41,7 @@ class LoginApi(Resource):
}
token = AccountService.login(account, ip_address=get_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token}
@ -89,111 +55,56 @@ class LogoutApi(Resource):
return {"result": "success"}
class ResetPasswordSendEmailApi(Resource):
class ResetPasswordApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
def get(self):
# parser = reqparse.RequestParser()
# parser.add_argument('email', type=email, required=True, location='json')
# args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
# import mailchimp_transactional as MailchimpTransactional
# from mailchimp_transactional.api_client import ApiClientError
account = AccountService.get_user_through_email(args["email"])
if account is None:
if dify_config.ALLOW_REGISTER:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, language=language)
# account = {'email': args['email']}
# account = AccountService.get_by_email(args['email'])
# if account is None:
# raise ValueError('Email not found')
# new_password = AccountService.generate_password()
# AccountService.update_password(account, new_password)
return {"result": "success", "data": token}
# todo: Send email
# MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
# message = {
# 'from_email': 'noreply@example.com',
# 'to': [{'email': account['email']}],
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# """
# }
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
# response = mailchimp.messages.send({
# 'message': message,
# # required for transactional email
# ' settings': {
# 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE,
# },
# })
ip_address = get_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
# Check if MSG was sent
# if response.status_code != 200:
# # handle error
# pass
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = AccountService.get_user_through_email(args["email"])
if account is None:
if dify_config.ALLOW_REGISTER:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"])
account = AccountService.get_user_through_email(user_email)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not dify_config.ALLOW_CREATE_WORKSPACE:
raise NotAllowedCreateWorkspace()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
token = AccountService.login(account, ip_address=get_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token}
return {"result": "success"}
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")

View File

@ -5,18 +5,14 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import get_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from .. import api
@ -46,7 +42,6 @@ def get_oauth_providers():
class OAuthLogin(Resource):
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
@ -54,7 +49,7 @@ class OAuthLogin(Resource):
if not oauth_provider:
return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
auth_url = oauth_provider.get_authorization_url()
return redirect(auth_url)
@ -67,11 +62,6 @@ class OAuthCallback(Resource):
return {"error": "Invalid provider"}, 400
code = request.args.get("code")
state = request.args.get("state")
invite_token = None
if state:
invite_token = state
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
@ -79,27 +69,7 @@ class OAuthCallback(Resource):
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService._get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except WorkSpaceNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
account = _generate_account(provider, user_info)
# Check account status
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403
@ -109,15 +79,7 @@ class OAuthCallback(Resource):
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
try:
TenantService.create_owner_tenant_if_not_exist(account)
except Unauthorized:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
TenantService.create_owner_tenant_if_not_exist(account)
token = AccountService.login(account, ip_address=get_remote_ip(request))
@ -137,20 +99,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not dify_config.ALLOW_CREATE_WORKSPACE:
raise WorkSpaceNotAllowedCreateError()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if not account:
if not dify_config.ALLOW_REGISTER:
raise AccountNotFoundError()
# Create account
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider

View File

@ -49,7 +49,7 @@ class DatasetListApi(Resource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
provider = request.args.get("provider", default="vendor")
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
@ -57,7 +57,7 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
)
# check embedding setting
@ -110,6 +110,26 @@ class DatasetListApi(Resource):
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
)
parser.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
parser.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -123,6 +143,9 @@ class DatasetListApi(Resource):
indexing_technique=args["indexing_technique"],
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -211,6 +234,33 @@ class DatasetApi(Resource):
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()
@ -563,10 +613,10 @@ class DatasetRetrievalSettingApi(Resource):
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -577,7 +627,6 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
):
return {
"retrieval_method": [

View File

@ -0,0 +1,282 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.")
return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class ExternalApiTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_user.current_tenant_id, search
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
"has_more": len(external_knowledge_apis) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return external_knowledge_api.to_dict(), 201
class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
if external_knowledge_api is None:
raise NotFound("API template not found.")
return external_knowledge_api.to_dict(), 200
@setup_required
@login_required
@account_initialization_required
def patch(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
args=args,
)
return external_knowledge_api.to_dict(), 200
@setup_required
@login_required
@account_initialization_required
def delete(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor or current_user.is_dataset_operator:
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 200
class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
class ExternalDatasetInitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=True, location="json")
# parser.add_argument('name', nullable=False, required=True,
# help='name is required. Name must be between 1 to 100 characters.',
# type=_validate_name)
# parser.add_argument('description', type=str, required=True, nullable=True, location='json')
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_parameter", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
# validate args
ExternalDatasetService.document_create_args_validate(
current_user.current_tenant_id, args["external_knowledge_api_id"], args["process_parameter"]
)
try:
dataset, documents, batch = ExternalDatasetService.init_external_dataset(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
args=args,
)
except Exception as ex:
raise ProviderNotInitializeError(ex.description)
response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
class ExternalDatasetCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
try:
dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
args=args,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 201
class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
)
return response
except Exception as e:
raise InternalServerError(str(e))
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")

View File

@ -47,6 +47,7 @@ class HitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@ -57,6 +58,7 @@ class HitTestingApi(Resource):
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)

View File

@ -0,0 +1,33 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.external_knowledge_service import ExternalDatasetService
class TestExternalApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
parser.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
parser.add_argument(
"knowledge_id",
nullable=False,
required=True,
type=str,
)
args = parser.parse_args()
result = ExternalDatasetService.test_external_knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
)
return result, 200
api.add_resource(TestExternalApi, "/retrieval")

View File

@ -38,21 +38,3 @@ class AlreadyActivateError(BaseHTTPException):
error_code = "already_activate"
description = "Auth Token is invalid or account already activated, please check again."
code = 403
class NotAllowedCreateWorkspace(BaseHTTPException):
error_code = "unauthorized"
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
code = 400
class NotAllowedRegister(BaseHTTPException):
error_code = "unauthorized"
description = "Account not found."
code = 400
class EmailSendIpLimitError(BaseHTTPException):
error_code = "email_send_ip_limit"
description = "Too many emails have been sent from this IP address recently. Please try again later."
code = 429

View File

@ -28,11 +28,11 @@ class DatasetListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
provider = request.args.get("provider", default="vendor")
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids)
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
@ -82,6 +82,26 @@ class DatasetListApi(DatasetApiResource):
required=False,
nullable=False,
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
parser.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
parser.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
args = parser.parse_args()
try:
@ -91,6 +111,9 @@ class DatasetListApi(DatasetApiResource):
indexing_technique=args["indexing_technique"],
account=current_user,
permission=args["permission"],
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()

View File

@ -231,8 +231,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,

View File

@ -212,8 +212,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e:
logger.error(e)
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,

View File

@ -248,8 +248,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
if publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None

View File

@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get("position"),
position=item.get("position") or 0,
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),

View File

@ -119,7 +119,7 @@ class ProviderConfiguration(BaseModel):
credentials = model_configuration.credentials
break
if not credentials and self.custom_configuration.provider:
if self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
return credentials

View File

@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -14,7 +13,7 @@ _TEXT_COLOR_MAPPING = {
}
class Callback(ABC):
class Callback:
"""
Base class for callbacks.
Only for LLM.
@ -22,7 +21,6 @@ class Callback(ABC):
raise_error: bool = False
@abstractmethod
def on_before_invoke(
self,
llm_instance: AIModel,
@ -50,7 +48,6 @@ class Callback(ABC):
"""
raise NotImplementedError()
@abstractmethod
def on_new_chunk(
self,
llm_instance: AIModel,
@ -80,7 +77,6 @@ class Callback(ABC):
"""
raise NotImplementedError()
@abstractmethod
def on_after_invoke(
self,
llm_instance: AIModel,
@ -110,7 +106,6 @@ class Callback(ABC):
"""
raise NotImplementedError()
@abstractmethod
def on_invoke_error(
self,
llm_instance: AIModel,

View File

@ -1,310 +0,0 @@
## Custom Integration of Pre-defined Models
### Introduction
After completing the vendors integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration.
It is important to note that for custom models, each model connection requires a complete vendor credential.
Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file.
![](images/index/image-3.png)
As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user.
### Writing the Vendor YAML
First, we need to identify the types of models supported by the vendor we are integrating.
Currently supported model types are as follows:
- `llm` Text Generation Models
- `text_embedding` Text Embedding Models
- `rerank` Rerank Models
- `speech2text` Speech-to-Text
- `tts` Text-to-Speech
- `moderation` Moderation
Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml.
```yaml
provider: xinference #Define the vendor identifier
label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
en_US: Xorbits Inference
icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
en_US: icon_s_en.svg
icon_large: # Large icon
en_US: icon_l_en.svg
help: # Help information
title:
en_US: How to deploy Xinference
zh_Hans: 如何部署 Xinference
url:
en_US: https://github.com/xorbitsai/inference
supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
- llm
- text-embedding
- rerank
configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
- customizable-model
provider_credential_schema:
credential_form_schemas:
```
Then, we need to determine what credentials are required to define a model in Xinference.
- Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it:
```yaml
provider_credential_schema:
credential_form_schemas:
- variable: model_type
type: select
label:
en_US: Model type
zh_Hans: 模型类型
required: true
options:
- value: text-generation
label:
en_US: Language Model
zh_Hans: 语言模型
- value: embeddings
label:
en_US: Text Embedding
- value: reranking
label:
en_US: Rerank
```
- Next, each model has its own model_name, so we need to define that here:
```yaml
- variable: model_name
type: text-input
label:
en_US: Model name
zh_Hans: 模型名称
required: true
placeholder:
zh_Hans: 填写模型名称
en_US: Input model name
```
- Specify the Xinference local deployment address:
```yaml
- variable: server_url
label:
zh_Hans: 服务器URL
en_US: Server url
type: text-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
```
- Each model has a unique model_uid, so we also need to define that here:
```yaml
- variable: model_uid
label:
zh_Hans: 模型UID
en_US: Model uid
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的Model UID
en_US: Enter the model uid
```
Now, we have completed the basic definition of the vendor.
### Writing the Model Code
Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`.
In `llm.py`, create a Xinference LLM class, we name it `XinferenceAILargeLanguageModel` (this can be arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Invocation
Implement the core method for LLM invocation, supporting both stream and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool usage
:param stop: stop words
:param stream: is the response a stream
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
When implementing, ensure to use two functions to return data separately for synchronous and stream responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. Heres an example (note that the example uses simplified parameters; in real implementation, use the parameter list as defined above):
```python
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
```
- Pre-compute Input Tokens
If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool usage
:return: token count
"""
```
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens. This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
- Model Credentials Validation
Similar to vendor credentials validation, this method validates individual model credentials.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return: None
"""
```
- Model Parameter Schema
Unlike custom types, since the YAML file does not define which parameters a model supports, we need to dynamically generate the model parameter schema.
For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature', type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度', en_US='Temperature'
)
),
ParameterRule(
name='top_p', type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P', en_US='Top P'
)
),
ParameterRule(
name='max_tokens', type=ParameterType.INT,
use_template='max_tokens',
min=1,
default=512,
label=I18nObject(
zh_Hans='最大生成长度', en_US='Max Tokens'
)
)
]
# if model is A, add top_k to rules
if model == 'A':
rules.append(
ParameterRule(
name='top_k', type=ParameterType.INT,
use_template='top_k',
min=1,
default=50,
label=I18nObject(
zh_Hans='Top K', en_US='Top K'
)
)
)
"""
some NOT IMPORTANT code here
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=model_type,
model_properties={
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules=rules
)
return entity
```
- Exception Error Mapping
When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
Runtime Errors:
- `InvokeConnectionError` Connection error during invocation
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failure
- `InvokeBadRequestError` Invalid request parameters
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).

Binary file not shown.

Before

Width:  |  Height:  |  Size: 230 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 205 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 262 KiB

View File

@ -1,173 +0,0 @@
## Predefined Model Integration
After completing the vendor integration, the next step is to integrate the models from the vendor.
First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
Currently supported model types are:
- `llm` Text Generation Model
- `text_embedding` Text Embedding Model
- `rerank` Rerank Model
- `speech2text` Speech-to-Text
- `tts` Text-to-Speech
- `moderation` Moderation
Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
### Prepare Model YAML
```yaml
model: claude-2.1 # Model identifier
# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
# This can also be omitted, in which case the model identifier will be used as the label
label:
en_US: claude-2.1
model_type: llm # Model type, claude-2.1 is an LLM
features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
- agent-thought
model_properties: # Model properties
mode: chat # LLM mode, complete for text completion models, chat for conversation models
context_size: 200000 # Maximum context size
parameter_rules: # Parameter rules for the model call; only LLM requires this
- name: temperature # Parameter variable name
# Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
# The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
# Additional configuration parameters will override the default configuration if set
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label: # Display name of the parameter
zh_Hans: 取样数量
en_US: Top k
type: int # Parameter type, supports float/int/string/boolean
help: # Help information, describing the parameter's function
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false # Whether the parameter is mandatory; can be omitted
- name: max_tokens_to_sample
use_template: max_tokens
default: 4096 # Default value of the parameter
min: 1 # Minimum value of the parameter, applicable to float/int only
max: 4096 # Maximum value of the parameter, applicable to float/int only
pricing: # Pricing information
input: '8.00' # Input unit price, i.e., prompt price
output: '24.00' # Output unit price, i.e., response content price
unit: '0.000001' # Price unit, meaning the above prices are per 100K
currency: USD # Price currency
```
It is recommended to prepare all model configurations before starting the implementation of the model code.
You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
### Implement the Model Call Code
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Call
Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
Ensure to use two functions for returning data, one for synchronous returns and the other for streaming returns, because Python identifies functions containing the `yield` keyword as generator functions, fixing the return type to `Generator`. Thus, synchronous and streaming returns need to be implemented separately, as shown below (note that the example uses simplified parameters, for actual implementation follow the above parameter list):
```python
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
```
- Pre-compute Input Tokens
If the model does not provide an interface to precompute tokens, return 0 directly.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
- Validate Model Credentials
Similar to vendor credential validation, but specific to a single model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
- Map Invoke Errors
When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
Runtime Errors:
- `InvokeConnectionError` Connection error
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failed
- `InvokeBadRequestError` Parameter error
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).

View File

@ -58,7 +58,7 @@ provider_credential_schema: # Provider credential rules, as Anthropic only supp
en_US: Enter your API URL
```
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider).
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider).
### Implementing Provider Code

View File

@ -117,7 +117,7 @@ model_credential_schema:
en_US: Enter your API Base
```
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。
#### 实现供应商代码

View File

@ -40,4 +40,3 @@
- fireworks
- mixedbread
- nomic
- voyage

View File

@ -6,8 +6,6 @@
- anthropic.claude-v2:1
- anthropic.claude-3-sonnet-v1:0
- anthropic.claude-3-haiku-v1:0
- ai21.jamba-1-5-large-v1:0
- ai21.jamba-1-5-mini-v1:0
- cohere.command-light-text-v14
- cohere.command-text-v14
- cohere.command-r-plus-v1.0
@ -17,10 +15,6 @@
- meta.llama3-1-405b-instruct-v1:0
- meta.llama3-8b-instruct-v1:0
- meta.llama3-70b-instruct-v1:0
- us.meta.llama3-2-1b-instruct-v1:0
- us.meta.llama3-2-3b-instruct-v1:0
- us.meta.llama3-2-11b-instruct-v1:0
- us.meta.llama3-2-90b-instruct-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-large-2407-v1:0

View File

@ -1,26 +0,0 @@
model: ai21.jamba-1-5-large-v1:0
label:
en_US: Jamba 1.5 Large
model_type: llm
model_properties:
mode: completion
context_size: 256000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
min: 0.0
max: 2.0
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
pricing:
input: '0.002'
output: '0.008'
unit: '0.001'
currency: USD

View File

@ -1,26 +0,0 @@
model: ai21.jamba-1-5-mini-v1:0
label:
en_US: Jamba 1.5 Mini
model_type: llm
model_properties:
mode: completion
context_size: 256000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
min: 0.0
max: 2.0
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
pricing:
input: '0.0002'
output: '0.0004'
unit: '0.001'
currency: USD

View File

@ -63,7 +63,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
{"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "us.meta.llama3-2", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False},
@ -71,7 +70,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
{"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False},
]
@staticmethod

View File

@ -1,29 +0,0 @@
model: us.meta.llama3-2-11b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 11B Instruct
model_type: llm
features:
- vision
- tool-call
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00035'
output: '0.00035'
unit: '0.001'
currency: USD

View File

@ -1,26 +0,0 @@
model: us.meta.llama3-2-1b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 1B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.0001'
output: '0.0001'
unit: '0.001'
currency: USD

View File

@ -1,26 +0,0 @@
model: us.meta.llama3-2-3b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 3B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00015'
output: '0.00015'
unit: '0.001'
currency: USD

View File

@ -1,31 +0,0 @@
model: us.meta.llama3-2-90b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 90B Instruct
model_type: llm
features:
- tool-call
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
default: 0.9
min: 0
max: 1
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.002'
output: '0.002'
unit: '0.001'
currency: USD

View File

@ -1,23 +1,24 @@
- Qwen2.5-72B-Instruct
- Qwen2.5-7B-Instruct
- Qwen2-72B-Instruct
- Qwen2-72B-Instruct-AWQ-int4
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-7B-Instruct
- Qwen2-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen-14B-Chat-Int4
- Yi-Coder-1.5B-Chat
- Yi-Coder-9B-Chat
- Qwen2-72B-Instruct-AWQ-int4
- Yi-1_5-9B-Chat-16K
- Qwen2-7B-Instruct
- Reflection-Llama-3.1-70B
- Qwen2-72B-Instruct
- Meta-Llama-3.1-8B-Instruct
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
- chatglm3-6b
- Meta-Llama-3-8B-Instruct
- Llama3-Chinese_v2
- deepseek-v2-lite-chat
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-7B
- Qwen-14B-Chat-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- deepseek-v2-chat
- chatglm3-6b

View File

@ -1,4 +0,0 @@
- gte-Qwen2-7B-instruct
- BAAI/bge-large-en-v1.5
- BAAI/bge-large-zh-v1.5
- BAAI/bge-m3

View File

@ -2,4 +2,3 @@ model: gte-Qwen2-7B-instruct
model_type: text-embedding
model_properties:
context_size: 2048
deprecated: true

View File

@ -1,17 +1,18 @@
- Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2.5-Math-72B-Instruct
- Qwen/Qwen2.5-32B-Instruct
- Qwen/Qwen2.5-14B-Instruct
- Qwen/Qwen2.5-7B-Instruct
- Qwen/Qwen2.5-Coder-7B-Instruct
- Qwen/Qwen2.5-Math-72B-Instruct
- deepseek-ai/DeepSeek-V2.5
- Qwen/Qwen2-72B-Instruct
- Qwen/Qwen2-57B-A14B-Instruct
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-1.5B-Instruct
- deepseek-ai/DeepSeek-V2.5
- deepseek-ai/DeepSeek-V2-Chat
- deepseek-ai/DeepSeek-Coder-V2-Instruct
- THUDM/glm-4-9b-chat
- THUDM/chatglm3-6b
- 01-ai/Yi-1.5-34B-Chat-16K
- 01-ai/Yi-1.5-9B-Chat-16K
- 01-ai/Yi-1.5-6B-Chat
@ -25,4 +26,13 @@
- google/gemma-2-27b-it
- google/gemma-2-9b-it
- mistralai/Mistral-7B-Instruct-v0.2
- mistralai/Mixtral-8x7B-Instruct-v0.1
- Pro/Qwen/Qwen2-7B-Instruct
- Pro/Qwen/Qwen2-1.5B-Instruct
- Pro/THUDM/glm-4-9b-chat
- Pro/THUDM/chatglm3-6b
- Pro/01-ai/Yi-1.5-9B-Chat-16K
- Pro/01-ai/Yi-1.5-6B-Chat
- Pro/internlm/internlm2_5-7b-chat
- Pro/meta-llama/Meta-Llama-3.1-8B-Instruct
- Pro/meta-llama/Meta-Llama-3-8B-Instruct
- Pro/google/gemma-2-9b-it

View File

@ -1,30 +0,0 @@
model: internlm/internlm2_5-20b-chat
label:
en_US: internlm/internlm2_5-20b-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '1'
output: '1'
unit: '0.000001'
currency: RMB

View File

@ -1,74 +0,0 @@
model: Qwen/Qwen2.5-Coder-7B-Instruct
label:
en_US: Qwen/Qwen2.5-Coder-7B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 8192
min: 1
max: 8192
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: seed
required: false
type: int
default: 1234
label:
zh_Hans: 随机种子
en_US: Random seed
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
zh_Hans: 重复惩罚
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
- name: response_format
use_template: response_format
pricing:
input: '0'
output: '0'
unit: '0.000001'
currency: RMB

View File

@ -1,74 +0,0 @@
model: Qwen/Qwen2.5-Math-72B-Instruct
label:
en_US: Qwen/Qwen2.5-Math-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 2000
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: seed
required: false
type: int
default: 1234
label:
zh_Hans: 随机种子
en_US: Random seed
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
zh_Hans: 重复惩罚
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
- name: response_format
use_template: response_format
pricing:
input: '4.13'
output: '4.13'
unit: '0.000001'
currency: RMB

View File

@ -1,7 +1,7 @@
# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models
model: qwen2.5-coder-7b-instruct
model: qwen2.5-7b-instruct
label:
en_US: qwen2.5-coder-7b-instruct
en_US: qwen2.5-7b-instruct
model_type: llm
features:
- agent-thought

View File

@ -1,21 +0,0 @@
<svg version="1.0" xmlns="http://www.w3.org/2000/svg" width="100.000000pt" height="19.000000pt" viewBox="0 0 300.000000 57.000000" preserveAspectRatio="xMidYMid meet"><g transform="translate(0.000000,57.000000) scale(0.100000,-0.100000)" fill="#000000" stroke="none"><path d="M2505 368 c-38 -84 -86 -188 -106 -230 l-38 -78 27 0 c24 0 30 7 55
75 l28 75 100 0 100 0 25 -55 c13 -31 24 -64 24 -75 0 -17 7 -20 44 -20 l43 0
-37 73 c-20 39 -68 143 -106 229 -38 87 -74 158 -80 158 -5 0 -41 -69 -79
-152z m110 -30 c22 -51 41 -95 42 -98 2 -3 -36 -6 -83 -7 -76 -1 -85 0 -81 15
12 40 72 182 77 182 3 0 24 -41 45 -92z"/><path d="M63 493 c19 -61 197 -438 209 -440 10 -2 147 282 216 449 2 4 -10 8
-27 8 -23 0 -31 -5 -31 -17 0 -16 -142 -365 -146 -360 -8 11 -144 329 -149
350 -6 23 -12 27 -42 27 -29 0 -34 -3 -30 -17z"/><path d="M2855 285 l0 -225 30 0 30 0 0 225 0 225 -30 0 -30 0 0 -225z"/><path d="M588 380 c-55 -30 -82 -74 -86 -145 -3 -50 0 -66 20 -95 39 -58 82
-80 153 -80 68 0 110 21 149 73 32 43 30 150 -3 196 -47 66 -158 90 -233 51z
m133 -16 c59 -30 89 -156 54 -224 -45 -87 -162 -78 -201 16 -18 44 -18 128 1
164 28 55 90 73 146 44z"/><path d="M935 303 l76 -98 -7 -72 -6 -73 33 0 34 0 -3 78 -4 77 71 93 c65 85
68 92 46 92 -15 0 -29 -9 -36 -22 -18 -33 -90 -128 -98 -128 -6 1 -67 85 -88
122 -8 15 -24 23 -53 25 l-41 4 76 -98z"/><path d="M1257 230 c-82 -169 -83 -170 -57 -170 17 0 27 6 27 15 0 8 7 31 17
52 l17 38 79 0 78 1 16 -34 c9 -18 16 -42 16 -52 0 -17 7 -20 41 -20 22 0 39
3 37 8 -2 4 -39 80 -83 170 -43 89 -84 162 -92 162 -7 0 -50 -76 -96 -170z
m90 -38 c-33 -2 -61 -1 -63 1 -2 2 10 34 26 71 l31 68 33 -68 33 -69 -60 -3z"/><path d="M1665 386 c-37 -16 -84 -63 -97 -96 -13 -35 -12 -104 2 -132 49 -94
182 -134 280 -83 24 12 29 22 32 64 3 49 3 49 -30 53 l-33 4 3 -45 c4 -61 -5
-71 -60 -71 -93 0 -142 57 -142 164 0 44 5 60 25 85 47 55 136 65 184 20 30
-28 35 -20 11 19 -19 31 -22 32 -82 32 -35 -1 -76 -7 -93 -14z"/><path d="M1955 230 l0 -170 91 0 c76 0 93 3 98 16 4 9 5 18 4 20 -2 1 -31 -1
-66 -5 -34 -4 -64 -5 -67 -3 -3 3 -5 36 -5 73 l0 68 55 -6 c49 -5 55 -4 55 13
0 17 -6 19 -55 16 l-55 -4 0 61 0 61 64 0 c48 0 65 4 70 15 4 13 -10 15 -92
15 l-97 0 0 -170z"/></g></svg>

Before

Width:  |  Height:  |  Size: 2.2 KiB

View File

@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="64px" height="64px" viewBox="0 0 64 64" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>voyage</title>
<g id="voyage" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<rect id="矩形" fill="#333333" x="0" y="0" width="64" height="64" rx="12"></rect>
<path d="M12.1128004,51.4376727 C13.8950799,45.8316747 30.5922254,11.1847688 31.7178757,11.0009656 C32.6559176,10.8171624 45.5070913,36.9172188 51.9795803,52.2647871 C52.1671887,52.6323936 51.0415384,53 49.4468672,53 C47.2893709,53 46.5389374,52.540492 46.5389374,51.4376727 C46.5389374,49.967247 33.2187427,17.8935861 32.8435259,18.3530942 C32.0930924,19.3640118 19.3357228,48.5887229 18.8667019,50.5186566 C18.3038768,52.6323936 17.7410516,53 14.926926,53 C12.2066045,53 11.7375836,52.7242952 12.1128004,51.4376727 Z" id="路径" fill="#FFFFFF" transform="translate(32, 32) scale(1, -1) translate(-32, -32)"></path>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 1.0 KiB

View File

@ -1,4 +0,0 @@
model: rerank-1
model_type: rerank
model_properties:
context_size: 8000

View File

@ -1,4 +0,0 @@
model: rerank-lite-1
model_type: rerank
model_properties:
context_size: 4000

View File

@ -1,123 +0,0 @@
from typing import Optional
import httpx
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class VoyageRerankModel(RerankModel):
"""
Model class for Voyage rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
base_url = base_url.removesuffix("/")
try:
response = httpx.post(
base_url + "/rerank",
json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results["data"]:
rerank_document = RerankDocument(
index=result["index"],
text=result["document"],
score=result["relevance_score"],
)
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
)
return entity

View File

@ -1,172 +0,0 @@
import time
from json import JSONDecodeError, dumps
from typing import Optional
import requests
from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
class VoyageTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Voyage text embedding model.
"""
api_base: str = "https://api.voyageai.com/v1"
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
api_key = credentials["api_key"]
if not api_key:
raise CredentialsValidateFailedError("api_key is required")
base_url = credentials.get("base_url", self.api_base)
base_url = base_url.removesuffix("/")
url = base_url + "/embeddings"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
voyage_input_type = "null"
if input_type is not None:
voyage_input_type = input_type.value
data = {"model": model, "input": texts, "input_type": voyage_input_type}
try:
response = requests.post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
msg = resp["detail"]
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
raise InvokeRateLimitError(msg)
elif response.status_code == 500:
raise InvokeServerUnavailableError(msg)
else:
raise InvokeBadRequestError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
try:
resp = response.json()
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
result = TextEmbeddingResult(
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
)
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as e:
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at,
)
return usage
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity

View File

@ -1,8 +0,0 @@
model: voyage-3-lite
model_type: text-embedding
model_properties:
context_size: 32000
pricing:
input: '0.00002'
unit: '0.001'
currency: USD

View File

@ -1,8 +0,0 @@
model: voyage-3
model_type: text-embedding
model_properties:
context_size: 32000
pricing:
input: '0.00006'
unit: '0.001'
currency: USD

View File

@ -1,28 +0,0 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class VoyageProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
# Use `voyage-3` model for validate,
# no matter what model you pass in, text completion model or chat model
model_instance.validate_credentials(model="voyage-3", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -1,31 +0,0 @@
provider: voyage
label:
en_US: Voyage
description:
en_US: Embedding and Rerank Model Supported
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#EFFDFD"
help:
title:
en_US: Get your API key from Voyage AI
zh_Hans: 从 Voyage 获取 API Key
url:
en_US: https://dash.voyageai.com/
supported_model_types:
- text-embedding
- rerank
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key

View File

@ -48,7 +48,7 @@ from ._utils import (
)
if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")

View File

@ -248,7 +248,7 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i in range(len(args)):
for i, _ in enumerate(args):
try:
given_params.add(positional[i])
except IndexError:

View File

@ -45,7 +45,7 @@ class Jieba(BaseKeyword):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keywords_list = kwargs.get("keywords_list", None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:

View File

@ -10,6 +10,7 @@ from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -34,6 +35,9 @@ class RetrievalService:
weights: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
@ -108,6 +112,16 @@ class RetrievalService:
)
return all_documents
@classmethod
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset_id, query, external_retrieval_model
)
return all_documents
@classmethod
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list

View File

@ -56,7 +56,7 @@ class TencentVector(BaseVector):
return self._client.create_database(database_name=self._client_config.database)
def get_type(self) -> str:
return VectorType.TENCENT
return "tencent"
def to_index_struct(self) -> dict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

View File

@ -0,0 +1,10 @@
from pydantic import BaseModel
class DocumentContext(BaseModel):
"""
Model class for document context.
"""
content: str
score: float

View File

@ -17,6 +17,8 @@ class Document(BaseModel):
"""
metadata: Optional[dict] = Field(default_factory=dict)
provider: Optional[str] = "dify"
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.

View File

@ -28,11 +28,16 @@ class RerankModelRunner:
docs = []
doc_id = []
unique_documents = []
for document in documents:
dify_documents = [item for item in documents if item.provider == "dify"]
external_documents = [item for item in documents if item.provider == "external"]
for document in dify_documents:
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
for document in external_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
@ -46,14 +51,10 @@ class RerankModelRunner:
# format document
rerank_document = Document(
page_content=result.text,
metadata={
"doc_id": documents[result.index].metadata["doc_id"],
"doc_hash": documents[result.index].metadata["doc_hash"],
"document_id": documents[result.index].metadata["document_id"],
"dataset_id": documents[result.index].metadata["dataset_id"],
"score": result.score,
},
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
)
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
return rerank_documents

View File

@ -20,6 +20,7 @@ from core.ops.utils import measure_time
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
@ -30,6 +31,7 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -110,7 +112,7 @@ class DatasetRetrieval:
continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0:
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset)
@ -146,69 +148,93 @@ class DatasetRetrieval:
message_id,
)
document_score_list = {}
for item in all_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
source = {
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": invoke_from.to_source(),
"score": item.metadata.get("score"),
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list = {}
# deal with dify documents
if dify_documents:
for item in dify_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if show_retrieve_source:
context_list = []
resource_number = 1
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, None),
}
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=document_score_list.get(segment.index_node_id, None),
)
)
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=document_score_list.get(segment.index_node_id, None),
)
)
if show_retrieve_source:
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
if dataset and document:
source = {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, None),
}
if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
if hit_callback:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""
def single_retrieve(
@ -256,36 +282,58 @@ class DatasetRetrieval:
# get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
retrieval_method = "keyword_search"
else:
retrieval_method = retrieval_model_config["search_method"]
# get reranking model
reranking_model = (
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
)
# get score threshold
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
with measure_time() as timer:
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
results = []
if dataset.provider == "external":
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
external_retrieval_parameters=dataset.retrieval_model,
)
for external_document in external_documents:
document = Document(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
retrieval_method = "keyword_search"
else:
retrieval_method = retrieval_model_config["search_method"]
# get reranking model
reranking_model = (
retrieval_model_config["reranking_model"]
if retrieval_model_config["reranking_enable"]
else None
)
# get score threshold
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
with measure_time() as timer:
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
if results:
@ -356,7 +404,8 @@ class DatasetRetrieval:
self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
) -> None:
"""Handle retrieval end."""
for document in documents:
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
@ -409,35 +458,54 @@ class DatasetRetrieval:
if not dataset:
return []
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
if dataset.provider == "external":
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
)
if documents:
all_documents.extend(documents)
else:
if top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
for external_document in external_documents:
document = Document(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
all_documents.append(document)
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
all_documents.extend(documents)
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
)
if documents:
all_documents.extend(documents)
else:
if top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)
def to_dataset_retriever_tool(
self,

View File

@ -156,16 +156,34 @@ class KnowledgeRetrievalNode(BaseNode):
weights,
node_data.multiple_retrieval_config.reranking_enable,
)
context_list = []
if all_documents:
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": "workflow",
"score": item.metadata.get("score"),
},
"title": item.metadata.get("title"),
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
page_number_list = {}
for item in all_documents:
for item in dify_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
@ -186,13 +204,10 @@ class KnowledgeRetrievalNode(BaseNode):
Document.enabled == True,
Document.archived == False,
).first()
resource_number = 1
if dataset and document:
source = {
"metadata": {
"_source": "knowledge",
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
@ -212,9 +227,14 @@ class KnowledgeRetrievalNode(BaseNode):
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
context_list.append(source)
resource_number += 1
return context_list
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score"), reverse=True)
position = 1
for item in retrieval_resource_list:
item["metadata"]["position"] = position
position += 1
return retrieval_resource_list
@classmethod
def _extract_variable_selector_to_variable_mapping(

View File

@ -14,7 +14,7 @@ from models.dataset import Document
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
document_ids = kwargs.get("document_ids")
document_ids = kwargs.get("document_ids", None)
documents = []
start_at = time.perf_counter()
for document_id in document_ids:

View File

@ -38,9 +38,20 @@ dataset_retrieval_model_fields = {
"score_threshold_enabled": fields.Boolean,
"score_threshold": fields.Float,
}
external_retrieval_model_fields = {
"top_k": fields.Integer,
"score_threshold": fields.Float,
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
external_knowledge_info_fields = {
"external_knowledge_id": fields.String,
"external_knowledge_api_id": fields.String,
"external_knowledge_api_name": fields.String,
"external_knowledge_api_endpoint": fields.String,
}
dataset_detail_fields = {
"id": fields.String,
"name": fields.String,
@ -61,6 +72,8 @@ dataset_detail_fields = {
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
}
dataset_query_detail_fields = {

View File

@ -0,0 +1,11 @@
from flask_restful import fields
from libs.helper import TimestampField
external_knowledge_api_query_detail_fields = {
"id": fields.String,
"name": fields.String,
"setting": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
}

View File

@ -189,39 +189,23 @@ def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Resp
class TokenManager:
@classmethod
def generate_token(
cls,
token_type: str,
account: Optional[Account] = None,
email: Optional[str] = None,
additional_data: dict = None,
) -> str:
if account is None and email is None:
raise ValueError("Account or email must be provided")
account_id = account.id if account else None
account_email = account.email if account else email
if account_id:
old_token = cls._get_current_token_for_account(account_id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4())
token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
if additional_data:
token_data.update(additional_data)
expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"]
token_key = cls._get_token_key(token, token_type)
expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(token_key, expiry_time, json.dumps(token_data))
if account_id:
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token
@classmethod
@ -250,12 +234,9 @@ class TokenManager:
return current_token
@classmethod
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float]
):
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int):
key = cls._get_account_token_key(account_id, token_type)
expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(key, expiry_time, token)
redis_client.setex(key, expiry_hours * 60 * 60, token)
@classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:

View File

@ -1,6 +1,5 @@
import urllib.parse
from dataclasses import dataclass
from typing import Optional
import requests
@ -41,14 +40,12 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: Optional[str] = None):
def get_authorization_url(self):
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": "user:email", # Request only basic user information
}
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
@ -93,15 +90,13 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: Optional[str] = None):
def get_authorization_url(self):
params = {
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"scope": "openid email",
}
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):

View File

@ -13,7 +13,7 @@ def valid_password(password):
if re.match(pattern, password) is not None:
return password
raise ValueError("Password must contain letters and numbers, and the length must be greater than 8.")
raise ValueError("Not a valid password.")
def hash_password(password_str, salt_byte):

View File

@ -0,0 +1,48 @@
"""update-retrieval-resource
Revision ID: 6af6a521a53e
Revises: ec3df697ebbb
Create Date: 2024-09-24 09:22:43.570120
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6af6a521a53e'
down_revision = 'd57ba9ebb251'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -0,0 +1,73 @@
"""external_knowledge_api
Revision ID: 33f5fac87f29
Revises: 6af6a521a53e
Create Date: 2024-09-25 04:34:57.249436
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '33f5fac87f29'
down_revision = '6af6a521a53e'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('external_knowledge_apis',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('settings', sa.Text(), nullable=True),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='external_knowledge_apis_pkey')
)
with op.batch_alter_table('external_knowledge_apis', schema=None) as batch_op:
batch_op.create_index('external_knowledge_apis_name_idx', ['name'], unique=False)
batch_op.create_index('external_knowledge_apis_tenant_idx', ['tenant_id'], unique=False)
op.create_table('external_knowledge_bindings',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('external_knowledge_api_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('external_knowledge_id', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey')
)
with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
batch_op.create_index('external_knowledge_bindings_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('external_knowledge_bindings_external_knowledge_api_idx', ['external_knowledge_api_id'], unique=False)
batch_op.create_index('external_knowledge_bindings_external_knowledge_idx', ['external_knowledge_id'], unique=False)
batch_op.create_index('external_knowledge_bindings_tenant_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
batch_op.drop_index('external_knowledge_bindings_tenant_idx')
batch_op.drop_index('external_knowledge_bindings_external_knowledge_idx')
batch_op.drop_index('external_knowledge_bindings_external_knowledge_api_idx')
batch_op.drop_index('external_knowledge_bindings_dataset_idx')
op.drop_table('external_knowledge_bindings')
with op.batch_alter_table('external_knowledge_apis', schema=None) as batch_op:
batch_op.drop_index('external_knowledge_apis_tenant_idx')
batch_op.drop_index('external_knowledge_apis_name_idx')
op.drop_table('external_knowledge_apis')
# ### end Alembic commands ###

View File

@ -1,4 +1,4 @@
"""add-dataset-retrival-model
"""add-dataset-retrieval-model
Revision ID: fca025d3b60f
Revises: b3a09c049e8e

View File

@ -38,6 +38,7 @@ class Dataset(db.Model):
)
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
PROVIDER_LIST = ["vendor", "external", None]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
@ -71,6 +72,14 @@ class Dataset(db.Model):
def index_struct_dict(self):
return json.loads(self.index_struct) if self.index_struct else None
@property
def external_retrieval_model(self):
default_retrieval_model = {
"top_k": 2,
"score_threshold": 0.0,
}
return self.retrieval_model or default_retrieval_model
@property
def created_by_account(self):
return db.session.get(Account, self.created_by)
@ -162,6 +171,29 @@ class Dataset(db.Model):
return tags or []
@property
def external_knowledge_info(self):
if self.provider != "external":
return None
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
)
if not external_knowledge_binding:
return None
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.first()
)
if not external_knowledge_api:
return None
return {
"external_knowledge_id": external_knowledge_binding.external_knowledge_id,
"external_knowledge_api_id": external_knowledge_api.id,
"external_knowledge_api_name": external_knowledge_api.name,
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}
@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")
@ -687,3 +719,77 @@ class DatasetPermission(db.Model):
tenant_id = db.Column(StringUUID, nullable=False)
has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class ExternalKnowledgeApis(db.Model):
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
db.Index("external_knowledge_apis_name_idx", "name"),
)
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.String(255), nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
settings = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
def to_dict(self):
return {
"id": self.id,
"tenant_id": self.tenant_id,
"name": self.name,
"description": self.description,
"settings": self.settings_dict,
"dataset_bindings": self.dataset_bindings,
"created_by": self.created_by,
"created_at": self.created_at.isoformat(),
}
@property
def settings_dict(self):
try:
return json.loads(self.settings) if self.settings else None
except JSONDecodeError:
return None
@property
def dataset_bindings(self):
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
.filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all()
)
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
return dataset_bindings
class ExternalKnowledgeBindings(db.Model):
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
external_knowledge_api_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
external_knowledge_id = db.Column(db.Text, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

View File

@ -1423,10 +1423,10 @@ class DatasetRetrieverResource(db.Model):
position = db.Column(db.Integer, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
dataset_name = db.Column(db.Text, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=True)
document_name = db.Column(db.Text, nullable=False)
data_source_type = db.Column(db.Text, nullable=False)
segment_id = db.Column(StringUUID, nullable=False)
data_source_type = db.Column(db.Text, nullable=True)
segment_id = db.Column(StringUUID, nullable=True)
score = db.Column(db.Float, nullable=True)
content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=True)

53
api/poetry.lock generated
View File

@ -6644,19 +6644,6 @@ files = [
{file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
{file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
{file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
{file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
]
[package.dependencies]
@ -8074,29 +8061,29 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.6.8"
version = "0.6.5"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.6.8-py3-none-linux_armv6l.whl", hash = "sha256:77944bca110ff0a43b768f05a529fecd0706aac7bcce36d7f1eeb4cbfca5f0f2"},
{file = "ruff-0.6.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:27b87e1801e786cd6ede4ada3faa5e254ce774de835e6723fd94551464c56b8c"},
{file = "ruff-0.6.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd48f945da2a6334f1793d7f701725a76ba93bf3d73c36f6b21fb04d5338dcf5"},
{file = "ruff-0.6.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:677e03c00f37c66cea033274295a983c7c546edea5043d0c798833adf4cf4c6f"},
{file = "ruff-0.6.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9f1476236b3eacfacfc0f66aa9e6cd39f2a624cb73ea99189556015f27c0bdeb"},
{file = "ruff-0.6.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f5a2f17c7d32991169195d52a04c95b256378bbf0de8cb98478351eb70d526f"},
{file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5fd0d4b7b1457c49e435ee1e437900ced9b35cb8dc5178921dfb7d98d65a08d0"},
{file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8034b19b993e9601f2ddf2c517451e17a6ab5cdb1c13fdff50c1442a7171d87"},
{file = "ruff-0.6.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6cfb227b932ba8ef6e56c9f875d987973cd5e35bc5d05f5abf045af78ad8e098"},
{file = "ruff-0.6.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef0411eccfc3909269fed47c61ffebdcb84a04504bafa6b6df9b85c27e813b0"},
{file = "ruff-0.6.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:007dee844738c3d2e6c24ab5bc7d43c99ba3e1943bd2d95d598582e9c1b27750"},
{file = "ruff-0.6.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ce60058d3cdd8490e5e5471ef086b3f1e90ab872b548814e35930e21d848c9ce"},
{file = "ruff-0.6.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1085c455d1b3fdb8021ad534379c60353b81ba079712bce7a900e834859182fa"},
{file = "ruff-0.6.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:70edf6a93b19481affd287d696d9e311388d808671bc209fb8907b46a8c3af44"},
{file = "ruff-0.6.8-py3-none-win32.whl", hash = "sha256:792213f7be25316f9b46b854df80a77e0da87ec66691e8f012f887b4a671ab5a"},
{file = "ruff-0.6.8-py3-none-win_amd64.whl", hash = "sha256:ec0517dc0f37cad14a5319ba7bba6e7e339d03fbf967a6d69b0907d61be7a263"},
{file = "ruff-0.6.8-py3-none-win_arm64.whl", hash = "sha256:8d3bb2e3fbb9875172119021a13eed38849e762499e3cfde9588e4b4d70968dc"},
{file = "ruff-0.6.8.tar.gz", hash = "sha256:a5bf44b1aa0adaf6d9d20f86162b34f7c593bfedabc51239953e446aefc8ce18"},
{file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"},
{file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"},
{file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"},
{file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"},
{file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"},
{file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"},
{file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"},
{file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"},
{file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"},
{file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"},
{file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"},
{file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"},
{file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"},
{file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"},
{file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"},
{file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"},
{file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"},
{file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"},
]
[[package]]
@ -10501,4 +10488,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "c4580c22e2b220c8c80dbc3f765060a09e14874ed29b690c13a533bf0365e789"
content-hash = "1f9d36b61528276a0761d87ef4f9fa787b5c1b49ae85b238c86626fa1110e2e8"

View File

@ -123,7 +123,6 @@ FIRECRAWL_API_KEY = "fc-"
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa"
VOYAGE_API_KEY = "va-aaaaaaaaaaaaaaaaaaaa"
[tool.poetry]
name = "dify-api"
@ -221,6 +220,7 @@ volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
oci = "^2.133.0"
tos = "^2.7.1"
nomic = "^3.1.2"
validators = "0.21.0"
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"
@ -287,4 +287,4 @@ optional = true
[tool.poetry.group.lint.dependencies]
dotenv-linter = "~0.5.0"
ruff = "~0.6.8"
ruff = "~0.6.5"

View File

@ -0,0 +1,92 @@
import datetime
import time
import click
from sqlalchemy import func
from werkzeug.exceptions import NotFound
import app
from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document
@app.celery.task(queue="dataset")
def clean_unused_message_task():
click.echo(click.style("Start clean unused messages .", fg="green"))
clean_days = int(dify_config.CLEAN_DAY_SETTING)
start_at = time.perf_counter()
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
page = 1
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > thirty_days_ago,
)
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < thirty_days_ago,
)
.group_by(Document.dataset_id)
.subquery()
)
# Main query with join and filter
datasets = (
db.session.query(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < thirty_days_ago,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound:
break
if datasets.items is None or len(datasets.items) == 0:
break
page += 1
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
try:
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
# update document
update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
except Exception as e:
click.echo(
click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
end_at = time.perf_counter()
click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green"))

View File

@ -1,6 +1,5 @@
import base64
import logging
import random
import secrets
import uuid
from datetime import datetime, timedelta, timezone
@ -23,9 +22,7 @@ from models.model import DifySetup
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountNotLinkTenantError,
AccountPasswordError,
AccountRegisterError,
CannotOperateSelfError,
CurrentPasswordIncorrectError,
@ -33,21 +30,16 @@ from services.errors.account import (
LinkAccountIntegrateError,
MemberNotInTenantError,
NoPermissionError,
RateLimitExceededError,
RoleAlreadyAssignedError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task
class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
)
LOGIN_MAX_ERROR_LIMITS = 5
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
@staticmethod
def load_user(user_id: str) -> None | Account:
@ -93,34 +85,23 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: str = None) -> Account:
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = Account.query.filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
raise AccountLoginError("Invalid email or password.")
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise AccountLoginError("Account is banned or closed.")
if password and invite_token and account.password is None:
# if invite_token is valid, set password and password_salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
password_hashed = hash_password(password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
db.session.commit()
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountLoginError("Invalid email or password.")
return account
@staticmethod
@ -146,18 +127,9 @@ class AccountService:
@staticmethod
def create_account(
email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
interface_theme: str = "light",
is_setup: Optional[bool] = False,
email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light"
) -> Account:
"""create account"""
if not dify_config.ALLOW_REGISTER and not is_setup:
from controllers.console.error import NotAllowedRegister
raise NotAllowedRegister()
account = Account()
account.email = email
account.name = name
@ -184,19 +156,6 @@ class AccountService:
db.session.commit()
return account
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: Optional[str] = None
) -> Account:
"""create account"""
account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password
)
TenantService.create_owner_tenant_if_not_exist(account=account)
return account
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
"""Link account integrate"""
@ -255,9 +214,6 @@ class AccountService:
if ip_address:
AccountService.update_last_login(account, ip_address=ip_address)
exp = timedelta(days=30)
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
db.session.commit()
token = AccountService.get_account_jwt_token(account, exp=exp)
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
return token
@ -273,26 +229,13 @@ class AccountService:
return AccountService.load_user(account_id)
@classmethod
def send_reset_password_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
account_email = account.email if account else email
def send_reset_password_email(cls, account):
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
raise PasswordResetRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code}
)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
token = TokenManager.generate_token(account, "reset_password")
send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token)
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
return token
@classmethod
@ -303,112 +246,6 @@ class AccountService:
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
raise EmailCodeLoginRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
)
send_email_code_login_mail_task.delay(
language=language,
to=account.email if account else email,
code=code,
)
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
def revoke_email_code_login_token(cls, token: str):
TokenManager.revoke_token(token, "email_code_login")
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).filter(Account.email == email).first()
if not account:
return None
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise Unauthorized("Account is banned or closed.")
return account
@staticmethod
def add_login_error_rate_limit(email: str) -> None:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, 60 * 60 * 24, count)
@staticmethod
def is_login_error_rate_limit(email: str) -> bool:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.LOGIN_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
def reset_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
def is_email_send_ip_limit(ip_address: str):
minute_key = f"email_send_ip_limit_minute:{ip_address}"
freeze_key = f"email_send_ip_limit_freeze:{ip_address}"
hour_limit_key = f"email_send_ip_limit_hour:{ip_address}"
# check ip is frozen
if redis_client.get(freeze_key):
return True
# check current minute count
current_minute_count = redis_client.get(minute_key)
if current_minute_count is None:
current_minute_count = 0
current_minute_count = int(current_minute_count)
# check current hour count
if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE:
hour_limit_count = redis_client.get(hour_limit_key)
if hour_limit_count is None:
hour_limit_count = 0
hour_limit_count = int(hour_limit_count)
if hour_limit_count >= 1:
redis_client.setex(freeze_key, 60 * 60, 1)
return True
else:
redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
# add hour limit count
redis_client.incr(hour_limit_key)
redis_client.expire(hour_limit_key, 60 * 60)
return True
redis_client.setex(minute_key, 60, current_minute_count + 1)
redis_client.expire(minute_key, 60)
return False
def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"
@ -416,12 +253,8 @@ def _get_login_cache_key(*, account_id: str, token: str):
class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False) -> Tenant:
def create_tenant(name: str) -> Tenant:
"""Create tenant"""
if not dify_config.ALLOW_CREATE_WORKSPACE and not is_setup:
from controllers.console.error import NotAllowedCreateWorkspace
raise NotAllowedCreateWorkspace()
tenant = Tenant(name=name)
db.session.add(tenant)
@ -432,12 +265,8 @@ class TenantService:
return tenant
@staticmethod
def create_owner_tenant_if_not_exist(
account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
):
def create_owner_tenant_if_not_exist(account: Account, name: Optional[str] = None):
"""Create owner tenant if not exist"""
if not dify_config.ALLOW_CREATE_WORKSPACE and not is_setup:
raise WorkSpaceNotAllowedCreateError()
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
)
@ -446,9 +275,9 @@ class TenantService:
return
if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
tenant = TenantService.create_tenant(name)
else:
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
db.session.commit()
@ -462,13 +291,8 @@ class TenantService:
logging.error(f"Tenant {tenant.id} has already an owner.")
raise Exception("Tenant already has an owner.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta:
ta.role = role
else:
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
db.session.add(ta)
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
db.session.add(ta)
db.session.commit()
return ta
@ -685,13 +509,12 @@ class RegisterService:
name=name,
interface_language=languages[0],
password=password,
is_setup=True,
)
account.last_login_ip = ip_address
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
TenantService.create_owner_tenant_if_not_exist(account)
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
db.session.add(dify_setup)
@ -728,16 +551,15 @@ class RegisterService:
if open_id is not None or provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
if dify_config.ALLOW_CREATE_WORKSPACE:
if dify_config.EDITION != "SELF_HOSTED":
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
db.session.commit()
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
except Exception as e:
db.session.rollback()
logging.error(f"Register failed: {e}")
@ -796,11 +618,6 @@ class RegisterService:
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
return token
@classmethod
def is_valid_invite_token(cls, token: str) -> bool:
data = redis_client.get(cls._get_invitation_token_key(token))
return data is not None
@classmethod
def revoke_token(cls, workspace_id: str, email: str, token: str):
if workspace_id and email:
@ -849,9 +666,7 @@ class RegisterService:
}
@classmethod
def _get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"

View File

@ -32,6 +32,7 @@ from models.dataset import (
DatasetQuery,
Document,
DocumentSegment,
ExternalKnowledgeBindings,
)
from models.model import UploadFile
from models.source import DataSourceOauthBinding
@ -39,6 +40,7 @@ from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.tag_service import TagService
from services.vector_service import VectorService
@ -56,10 +58,8 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by(
Dataset.created_at.desc()
)
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
@ -137,7 +137,14 @@ class DatasetService:
@staticmethod
def create_empty_dataset(
tenant_id: str, name: str, indexing_technique: Optional[str], account: Account, permission: Optional[str] = None
tenant_id: str,
name: str,
indexing_technique: Optional[str],
account: Account,
permission: Optional[str] = None,
provider: str = "vendor",
external_knowledge_api_id: Optional[str] = None,
external_knowledge_id: Optional[str] = None,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
@ -156,12 +163,28 @@ class DatasetService:
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
db.session.add(dataset)
db.session.flush()
if provider == "external" and external_knowledge_api_id:
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
if not external_knowledge_api:
raise ValueError("External API template not found.")
external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset.id,
external_knowledge_api_id=external_knowledge_api_id,
external_knowledge_id=external_knowledge_id,
created_by=account.id,
)
db.session.add(external_knowledge_binding)
db.session.commit()
return dataset
@staticmethod
def get_dataset(dataset_id):
def get_dataset(dataset_id) -> Dataset:
return Dataset.query.filter_by(id=dataset_id).first()
@staticmethod
@ -202,81 +225,103 @@ class DatasetService:
@staticmethod
def update_dataset(dataset_id, data, user):
data.pop("partial_member_list", None)
filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_permission(dataset, user)
action = None
if dataset.indexing_technique != data["indexing_technique"]:
# if update indexing_technique
if data["indexing_technique"] == "economy":
action = "remove"
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
elif data["indexing_technique"] == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if dataset.provider == "external":
dataset.retrieval_model = data.get("external_retrieval_model", None)
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", "")
external_knowledge_id = data.get("external_knowledge_id", None)
db.session.add(dataset)
if not external_knowledge_id:
raise ValueError("External knowledge id is required.")
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
if not external_knowledge_api_id:
raise ValueError("External knowledge api id is required.")
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id).first()
if (
data["embedding_model_provider"] != dataset.embedding_model_provider
or data["embedding_model"] != dataset.embedding_model
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
action = "update"
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding)
db.session.commit()
else:
data.pop("partial_member_list", None)
filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
action = None
if dataset.indexing_technique != data["indexing_technique"]:
# if update indexing_technique
if data["indexing_technique"] == "economy":
action = "remove"
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
elif data["indexing_technique"] == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if (
data["embedding_model_provider"] != dataset.embedding_model_provider
or data["embedding_model"] != dataset.embedding_model
):
action = "update"
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
filtered_data["updated_by"] = user.id
filtered_data["updated_at"] = datetime.datetime.now()
filtered_data["updated_by"] = user.id
filtered_data["updated_at"] = datetime.datetime.now()
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
dataset.query.filter_by(id=dataset_id).update(filtered_data)
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
db.session.commit()
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
@staticmethod

View File

@ -0,0 +1,26 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel
class AuthorizationConfig(BaseModel):
type: Literal[None, "basic", "bearer", "custom"]
api_key: Union[None, str] = None
header: Union[None, str] = None
class Authorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[AuthorizationConfig] = None
class ProcessStatusSetting(BaseModel):
request_method: str
url: str
class ExternalKnowledgeApiSetting(BaseModel):
url: str
request_method: str
headers: Optional[dict] = None
params: Optional[dict] = None

View File

@ -13,10 +13,6 @@ class AccountLoginError(BaseServiceError):
pass
class AccountPasswordError(BaseServiceError):
pass
class AccountNotLinkTenantError(BaseServiceError):
pass

View File

@ -1,9 +0,0 @@
from services.errors.base import BaseServiceError
class WorkSpaceNotAllowedCreateError(BaseServiceError):
pass
class WorkSpaceNotFoundError(BaseServiceError):
pass

View File

@ -0,0 +1,378 @@
import json
import random
import time
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Optional, Union
import boto3
import httpx
import validators
# from tasks.external_document_indexing_task import external_document_indexing_task
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from models.dataset import (
Dataset,
Document,
ExternalKnowledgeApis,
ExternalKnowledgeBindings,
)
from models.model import UploadFile
from services.entities.external_knowledge_entities.external_knowledge_entities import (
Authorization,
ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
ExternalKnowledgeApis.created_at.desc()
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
return external_knowledge_apis.items, external_knowledge_apis.total
@classmethod
def validate_api_list(cls, api_settings: dict):
if not api_settings:
raise ValueError("api list is empty")
if "endpoint" not in api_settings and not api_settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in api_settings and not api_settings["api_key"]:
raise ValueError("api_key is required")
@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
ExternalDatasetService.check_endpoint_and_api_key(args.get("settings"))
external_knowledge_api = ExternalKnowledgeApis(
tenant_id=tenant_id,
created_by=user_id,
updated_by=user_id,
name=args.get("name"),
description=args.get("description", ""),
settings=json.dumps(args.get("settings"), ensure_ascii=False),
)
db.session.add(external_knowledge_api)
db.session.commit()
return external_knowledge_api
@staticmethod
def check_endpoint_and_api_key(settings: dict):
if "endpoint" not in settings or not settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in settings or not settings["api_key"]:
raise ValueError("api_key is required")
endpoint = f"{settings['endpoint']}/retrieval"
api_key = settings["api_key"]
if not validators.url(endpoint):
raise ValueError(f"invalid endpoint: {endpoint}")
try:
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
except Exception as e:
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
if response.status_code == 502:
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
if response.status_code == 404:
raise ValueError(f"Not Found: failed to connect to the endpoint: {endpoint}")
if response.status_code == 403:
raise ValueError(f"Forbidden: Authorization failed with api_key: {api_key}")
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first()
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "")
external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
external_knowledge_api.updated_by = user_id
external_knowledge_api.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
return external_knowledge_api
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
db.session.delete(external_knowledge_api)
db.session.commit()
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
if count > 0:
return True, count
return False, 0
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings)
for setting in settings:
custom_parameters = setting.get("document_process_setting")
if custom_parameters:
for parameter in custom_parameters:
if parameter.get("required", False) and not process_parameter.get(parameter.get("name")):
raise ValueError(f'{parameter.get("name")} is required')
@staticmethod
def init_external_dataset(tenant_id: str, user_id: str, args: dict, created_from: str = "web"):
external_knowledge_api_id = args.get("external_knowledge_api_id")
data_source = args.get("data_source")
if data_source is None:
raise ValueError("data source is required")
process_parameter = args.get("process_parameter")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
dataset = Dataset(
tenant_id=tenant_id,
name=args.get("name"),
description=args.get("description", ""),
provider="external",
created_by=user_id,
)
db.session.add(dataset)
db.session.flush()
document = Document.query.filter_by(dataset_id=dataset.id).order_by(Document.position.desc()).first()
position = document.position + 1 if document else 1
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
document_ids = []
if data_source["type"] == "upload_file":
upload_file_list = data_source["info_list"]["file_info_list"]["file_ids"]
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
if file:
data_source_info = {
"upload_file_id": file_id,
}
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type=data_source["type"],
data_source_info=json.dumps(data_source_info),
batch=batch,
name=file.name,
created_from=created_from,
created_by=user_id,
)
position += 1
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
db.session.commit()
# external_document_indexing_task.delay(dataset.id, external_knowledge_api_id, data_source, process_parameter)
return dataset
@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
"""
do http request depending on api bundle
"""
kwargs = {
"url": settings.url,
"headers": settings.headers,
"follow_redirects": True,
}
response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs)
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
else:
headers = {}
if authorization.type == "api-key":
if authorization.config is None:
raise ValueError("authorization config is required")
if authorization.config.api_key is None:
raise ValueError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
if authorization.config.type == "bearer":
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif authorization.config.type == "basic":
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
elif authorization.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key
return headers
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.parse_obj(settings)
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
dataset = Dataset(
tenant_id=tenant_id,
name=args.get("name"),
description=args.get("description", ""),
provider="external",
retrieval_model=args.get("external_retrieval_model"),
created_by=user_id,
)
db.session.add(dataset)
db.session.flush()
external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset.id,
external_knowledge_api_id=args.get("external_knowledge_api_id"),
external_knowledge_id=args.get("external_knowledge_id"),
created_by=user_id,
)
db.session.add(external_knowledge_binding)
db.session.commit()
return dataset
@staticmethod
def fetch_external_knowledge_retrieval(
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_binding.external_knowledge_api_id
).first()
if not external_knowledge_api:
raise ValueError("external api template not found")
settings = json.loads(external_knowledge_api.settings)
headers = {"Content-Type": "application/json"}
if settings.get("api_key"):
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
score_threshold_enabled = external_retrieval_parameters.get("score_threshold_enabled") or False
score_threshold = external_retrieval_parameters.get("score_threshold", 0.0) if score_threshold_enabled else 0.0
request_params = {
"retrieval_setting": {
"top_k": external_retrieval_parameters.get("top_k"),
"score_threshold": score_threshold,
},
"query": query,
"knowledge_id": external_knowledge_binding.external_knowledge_id,
}
external_knowledge_api_setting = {
"url": f"{settings.get('endpoint')}/retrieval",
"request_method": "post",
"headers": headers,
"params": request_params,
}
response = ExternalDatasetService.process_external_api(
ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None
)
if response.status_code == 200:
return response.json().get("records", [])
return []
@staticmethod
def test_external_knowledge_retrieval(retrieval_setting: dict, query: str, external_knowledge_id: str):
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
region_name="us-east-1",
)
response = client.retrieve(
knowledgeBaseId=external_knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": retrieval_setting.get("top_k"),
"overrideSearchType": "HYBRID",
}
},
retrievalQuery={"text": query},
)
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
continue
result = {
"metadata": retrieval_result.get("metadata"),
"score": retrieval_result.get("score"),
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
"content": retrieval_result.get("content").get("text"),
}
results.append(result)
return {"records": results}

View File

@ -42,11 +42,6 @@ class SystemFeatureModel(BaseModel):
sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = ""
enable_web_sso_switch_component: bool = False
enable_email_code_login: bool = False
enable_email_password_login: bool = True
enable_social_oauth_login: bool = False
is_allow_register: bool = True
is_allow_create_workspace: bool = True
class FeatureService:
@ -65,22 +60,12 @@ class FeatureService:
def get_system_features(cls) -> SystemFeatureModel:
system_features = SystemFeatureModel()
cls.__fulfill_login_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True
cls._fulfill_params_from_enterprise(system_features)
return system_features
@classmethod
def __fulfill_login_params_from_env(cls, features: FeatureModel):
features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
features.is_allow_register = dify_config.ALLOW_REGISTER
features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
@ -132,5 +117,3 @@ class FeatureService:
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
features.enable_email_code_login = enterprise_info["enable_email_code_login"]
features.enable_email_password_login = enterprise_info["enable_email_password_login"]

View File

@ -19,7 +19,15 @@ default_retrieval_model = {
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
def retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict,
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
@ -62,10 +70,44 @@ class HitTestingService:
return cls.compact_retrieve_response(dataset, query, all_documents)
@classmethod
def external_retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
external_retrieval_model: dict,
) -> dict:
if dataset.provider != "external":
return {
"query": {"content": query},
"records": [],
}
start = time.perf_counter()
all_documents = RetrievalService.external_retrieve(
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
external_retrieval_model=external_retrieval_model,
)
end = time.perf_counter()
logging.debug(f"External knowledge hit testing retrieve in {end - start:0.4f} seconds")
dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
)
db.session.add(dataset_query)
db.session.commit()
return cls.compact_external_retrieve_response(dataset, query, all_documents)
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
i = 0
records = []
for document in documents:
index_node_id = document.metadata["doc_id"]
@ -81,7 +123,6 @@ class HitTestingService:
)
if not segment:
i += 1
continue
record = {
@ -91,8 +132,6 @@ class HitTestingService:
records.append(record)
i += 1
return {
"query": {
"content": query,
@ -100,6 +139,25 @@ class HitTestingService:
"records": records,
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list):
records = []
if dataset.provider == "external":
for document in documents:
record = {
"content": document.get("content", None),
"title": document.get("title", None),
"score": document.get("score", None),
"metadata": document.get("metadata", None),
}
records.append(record)
return {
"query": {
"content": query,
},
"records": records,
}
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]

View File

@ -0,0 +1,93 @@
import json
import logging
import time
import click
from celery import shared_task
from core.indexing_runner import DocumentIsPausedException
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, ExternalKnowledgeApis
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
@shared_task(queue="dataset")
def external_document_indexing_task(
dataset_id: str, external_knowledge_api_id: str, data_source: dict, process_parameter: dict
):
"""
Async process document
:param dataset_id:
:param external_knowledge_api_id:
:param data_source:
:param process_parameter:
Usage: external_document_indexing_task.delay(dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(
click.style("Processed external dataset: {} failed, dataset not exit.".format(dataset_id), fg="red")
)
return
# get external api template
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter(
ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == dataset.tenant_id
)
.first()
)
if not external_knowledge_api:
logging.info(
click.style(
"Processed external dataset: {} failed, api template: {} not exit.".format(
dataset_id, external_knowledge_api_id
),
fg="red",
)
)
return
files = {}
if data_source["type"] == "upload_file":
upload_file_list = data_source["info_list"]["file_info_list"]["file_ids"]
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
if file:
files[file.id] = (file.name, storage.load_once(file.key), file.mime_type)
try:
settings = ExternalDatasetService.get_external_knowledge_api_settings(
json.loads(external_knowledge_api.settings)
)
# assemble headers
headers = ExternalDatasetService.assembling_headers(settings.authorization, settings.headers)
# do http request
response = ExternalDatasetService.process_external_api(settings, headers, process_parameter, files)
job_id = response.json().get("job_id")
if job_id:
# save job_id to dataset
dataset.job_id = job_id
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
"Processed external dataset: {} successful, latency: {}".format(dataset.id, end_at - start_at),
fg="green",
)
)
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -1,41 +0,0 @@
import logging
import time
import click
from celery import shared_task
from flask import render_template
from extensions.ext_mail import mail
@shared_task(queue="mail")
def send_email_code_login_mail_task(language: str, to: str, code: str):
"""
Async Send email code login mail
:param language: Language in which the email should be sent (e.g., 'en', 'zh')
:param to: Recipient email address
:param code: Email code to be included in the email
"""
if not mail.is_inited():
return
logging.info(click.style("Start email code login mail to {}".format(to), fg="green"))
start_at = time.perf_counter()
# send email code login mail using different languages
try:
if language == "zh-Hans":
html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code)
mail.send(to=to, subject="邮箱验证码", html=html_content)
else:
html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code)
mail.send(to=to, subject="Email Code", html=html_content)
end_at = time.perf_counter()
logging.info(
click.style(
"Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
)
)
except Exception:
logging.exception("Send email code login mail to {} failed".format(to))

View File

@ -5,16 +5,17 @@ import click
from celery import shared_task
from flask import render_template
from configs import dify_config
from extensions.ext_mail import mail
@shared_task(queue="mail")
def send_reset_password_mail_task(language: str, to: str, code: str):
def send_reset_password_mail_task(language: str, to: str, token: str):
"""
Async Send reset password mail
:param language: Language in which the email should be sent (e.g., 'en', 'zh')
:param to: Recipient email address
:param code: Reset password code
:param token: Reset password token to be included in the email
"""
if not mail.is_inited():
return
@ -24,11 +25,12 @@ def send_reset_password_mail_task(language: str, to: str, code: str):
# send reset password mail using different languages
try:
url = f"{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}"
if language == "zh-Hans":
html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code)
html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, url=url)
mail.send(to=to, subject="重置您的 Dify 密码", html=html_content)
else:
html_content = render_template("reset_password_mail_template_en-US.html", to=to, code=code)
html_content = render_template("reset_password_mail_template_en-US.html", to=to, url=url)
mail.send(to=to, subject="Reset Your Dify Password", html=html_content)
end_at = time.perf_counter()

View File

@ -1,74 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: 'Arial', sans-serif;
line-height: 16pt;
color: #101828;
background-color: #e9ebf0;
margin: 0;
padding: 0;
}
.container {
width: 600px;
height: 360px;
margin: 40px auto;
padding: 36px 48px;
background-color: #fcfcfd;
border-radius: 16px;
border: 1px solid #ffffff;
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
}
.header {
margin-bottom: 24px;
}
.header img {
max-width: 100px;
height: auto;
}
.title {
font-weight: 600;
font-size: 24px;
line-height: 28.8px;
}
.description {
font-size: 13px;
line-height: 16px;
color: #676f83;
margin-top: 12px;
}
.code-content {
padding: 16px 32px;
text-align: center;
border-radius: 16px;
background-color: #f2f4f7;
margin: 16px auto;
}
.code {
line-height: 36px;
font-weight: 700;
font-size: 30px;
}
.tips {
line-height: 16px;
color: #676f83;
font-size: 13px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<!-- Optional: Add a logo or a header image here -->
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo" />
</div>
<p class="title">Your login code for Dify</p>
<p class="description">Copy and paste this code, this code will only be valid for the next 5 minutes.</p>
<div class="code-content">
<span class="code">{{code}}</span>
</div>
<p class="tips">If you didn't request a login, don't worry. You can safely ignore this email.</p>
</div>
</body>
</html>

View File

@ -1,74 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: 'Arial', sans-serif;
line-height: 16pt;
color: #101828;
background-color: #e9ebf0;
margin: 0;
padding: 0;
}
.container {
width: 600px;
height: 360px;
margin: 40px auto;
padding: 36px 48px;
background-color: #fcfcfd;
border-radius: 16px;
border: 1px solid #ffffff;
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
}
.header {
margin-bottom: 24px;
}
.header img {
max-width: 100px;
height: auto;
}
.title {
font-weight: 600;
font-size: 24px;
line-height: 28.8px;
}
.description {
font-size: 13px;
line-height: 16px;
color: #676f83;
margin-top: 12px;
}
.code-content {
padding: 16px 32px;
text-align: center;
border-radius: 16px;
background-color: #f2f4f7;
margin: 16px auto;
}
.code {
line-height: 36px;
font-weight: 700;
font-size: 30px;
}
.tips {
line-height: 16px;
color: #676f83;
font-size: 13px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<!-- Optional: Add a logo or a header image here -->
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo" />
</div>
<p class="title">Dify 的登录验证码</p>
<p class="description">复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。</p>
<div class="code-content">
<span class="code">{{code}}</span>
</div>
<p class="tips">如果您没有请求登陆,请不要担心。您可以安全地忽略此电子邮件。</p>
</div>
</body>
</html>

View File

@ -59,7 +59,7 @@
<div class="content">
<p>Dear {{ to }},</p>
<p>{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
<p>Click the button below to log in to Dify and join the workspace.</p>
<p>You can now log in to Dify using the GitHub or Google account associated with this email.</p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
</div>
<div class="footer">

View File

@ -59,7 +59,7 @@
<div class="content">
<p>尊敬的 {{ to }}</p>
<p>{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
<p>点击下方按钮即可登录 Dify 并且加入空间</p>
<p>您现在可以使用与此邮件相对应的 GitHub 或 Google 账号登录 Dify</p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
</div>
<div class="footer">

Some files were not shown because too many files have changed in this diff Show More