Merge branch 'main' into feat/memory-orchestration-be

# Conflicts:
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/prompt/entities/advanced_prompt_entities.py
#	api/core/variables/segments.py
This commit is contained in:
Stream
2025-09-15 14:14:56 +08:00
2025 changed files with 67244 additions and 18565 deletions

View File

@ -5,7 +5,7 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, Optional, cast
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func
@ -37,7 +37,6 @@ from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountNotLinkTenantError,
AccountPasswordError,
AccountRegisterError,
@ -65,7 +64,13 @@ from tasks.mail_owner_transfer_task import (
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
from tasks.mail_reset_password_task import send_reset_password_mail_task
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
from tasks.mail_reset_password_task import (
send_reset_password_mail_task,
send_reset_password_mail_task_when_account_not_exist,
)
logger = logging.getLogger(__name__)
class TokenPair(BaseModel):
@ -80,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1
)
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
@ -93,6 +99,7 @@ class AccountService:
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
EMAIL_REGISTER_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@ -103,14 +110,14 @@ class AccountService:
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
def _store_refresh_token(refresh_token: str, account_id: str):
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
)
@staticmethod
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
def _delete_refresh_token(refresh_token: str, account_id: str):
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
@ -143,8 +150,11 @@ class AccountService:
if naive_utc_now() - account.last_active_at > timedelta(minutes=10):
account.last_active_at = naive_utc_now()
db.session.commit()
return cast(Account, account)
# NOTE: make sure account is accessible outside of a db session
# This ensures that it will work correctly after upgrading to Flask version 3.1.2
db.session.refresh(account)
db.session.close()
return account
@staticmethod
def get_account_jwt_token(account: Account) -> str:
@ -161,12 +171,12 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
@ -189,7 +199,7 @@ class AccountService:
db.session.commit()
return cast(Account, account)
return account
@staticmethod
def update_account_password(account, password, new_password):
@ -209,6 +219,7 @@ class AccountService:
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.add(account)
db.session.commit()
return account
@ -217,9 +228,9 @@ class AccountService:
email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
password: str | None = None,
interface_theme: str = "light",
is_setup: Optional[bool] = False,
is_setup: bool | None = False,
) -> Account:
"""create account"""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
@ -240,6 +251,8 @@ class AccountService:
account.name = name
if password:
valid_password(password)
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
@ -263,7 +276,7 @@ class AccountService:
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: Optional[str] = None
email: str, name: str, interface_language: str, password: str | None = None
) -> Account:
"""create account"""
account = AccountService.create_account(
@ -288,7 +301,9 @@ class AccountService:
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
raise EmailCodeAccountDeletionRateLimitExceededError()
raise EmailCodeAccountDeletionRateLimitExceededError(
int(cls.email_code_account_deletion_rate_limiter.time_window / 60)
)
send_account_deletion_verification_code.delay(to=email, code=code)
@ -306,16 +321,16 @@ class AccountService:
return True
@staticmethod
def delete_account(account: Account) -> None:
def delete_account(account: Account):
"""Delete account. This method only adds a task to the queue for deletion."""
delete_account_task.delay(account.id)
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
def link_account_integrate(provider: str, open_id: str, account: Account):
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = (
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
)
@ -332,13 +347,13 @@ class AccountService:
db.session.add(account_integrate)
db.session.commit()
logging.info("Account %s linked %s account %s.", account.id, provider, open_id)
logger.info("Account %s linked %s account %s.", account.id, provider, open_id)
except Exception as e:
logging.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
def close_account(account: Account) -> None:
def close_account(account: Account):
"""Close account"""
account.status = AccountStatus.CLOSED.value
db.session.commit()
@ -346,6 +361,7 @@ class AccountService:
@staticmethod
def update_account(account, **kwargs):
"""Update account fields"""
account = db.session.merge(account)
for field, value in kwargs.items():
if hasattr(account, field):
setattr(account, field, value)
@ -367,7 +383,7 @@ class AccountService:
return account
@staticmethod
def update_login_info(account: Account, *, ip_address: str) -> None:
def update_login_info(account: Account, *, ip_address: str):
"""Update last login time and ip"""
account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
@ -375,7 +391,7 @@ class AccountService:
db.session.commit()
@staticmethod
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address)
@ -391,7 +407,7 @@ class AccountService:
return TokenPair(access_token=access_token, refresh_token=refresh_token)
@staticmethod
def logout(*, account: Account) -> None:
def logout(*, account: Account):
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
@ -423,9 +439,10 @@ class AccountService:
@classmethod
def send_reset_password_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
is_allow_register: bool = False,
):
account_email = account.email if account else email
if account_email is None:
@ -434,26 +451,67 @@ class AccountService:
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
raise PasswordResetRateLimitExceededError()
raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60))
code, token = cls.generate_reset_password_token(account_email, account)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
if account:
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
else:
send_reset_password_mail_task_when_account_not_exist.delay(
language=language,
to=account_email,
is_allow_register=is_allow_register,
)
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_email_register_email(
cls,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
if cls.email_register_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailRegisterRateLimitExceededError
raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60))
code, token = cls.generate_email_register_token(account_email)
if account:
send_email_register_mail_task_when_account_exist.delay(
language=language,
to=account_email,
account_name=account.name,
)
else:
send_email_register_mail_task.delay(
language=language,
to=account_email,
code=code,
)
cls.email_register_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_change_email_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
old_email: str | None = None,
language: str = "en-US",
phase: Optional[str] = None,
phase: str | None = None,
):
account_email = account.email if account else email
if account_email is None:
@ -464,7 +522,7 @@ class AccountService:
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError()
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
@ -480,8 +538,8 @@ class AccountService:
@classmethod
def send_change_email_completed_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
@ -496,10 +554,10 @@ class AccountService:
@classmethod
def send_owner_transfer_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@ -508,7 +566,7 @@ class AccountService:
if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import OwnerTransferRateLimitExceededError
raise OwnerTransferRateLimitExceededError()
raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60))
code, token = cls.generate_owner_transfer_token(account_email, account)
workspace_name = workspace_name or ""
@ -525,10 +583,10 @@ class AccountService:
@classmethod
def send_old_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
new_owner_email: str = "",
):
account_email = account.email if account else email
@ -546,10 +604,10 @@ class AccountService:
@classmethod
def send_new_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@ -566,8 +624,8 @@ class AccountService:
def generate_reset_password_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -578,13 +636,26 @@ class AccountService:
)
return code, token
@classmethod
def generate_email_register_token(
cls,
email: str,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
return code, token
@classmethod
def generate_change_email_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
old_email: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -600,8 +671,8 @@ class AccountService:
def generate_owner_transfer_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -616,6 +687,10 @@ class AccountService:
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
@classmethod
def revoke_email_register_token(cls, token: str):
TokenManager.revoke_token(token, "email_register")
@classmethod
def revoke_change_email_token(cls, token: str):
TokenManager.revoke_token(token, "change_email")
@ -625,22 +700,26 @@ class AccountService:
TokenManager.revoke_token(token, "owner_transfer")
@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_reset_password_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_register_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_register")
@classmethod
def get_change_email_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "change_email")
@classmethod
def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "owner_transfer")
@classmethod
def send_email_code_login_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
email = account.email if account else email
@ -649,7 +728,7 @@ class AccountService:
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
raise EmailCodeLoginRateLimitExceededError()
raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
@ -664,7 +743,7 @@ class AccountService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@ -698,7 +777,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_login_error_rate_limit(email: str) -> None:
def add_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -727,7 +806,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_forgot_password_error_rate_limit(email: str) -> None:
def add_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -735,6 +814,16 @@ class AccountService:
count = int(count) + 1
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=None)
def add_email_register_error_rate_limit(email: str) -> None:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool:
@ -754,9 +843,27 @@ class AccountService:
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=False)
def is_email_register_error_rate_limit(email: str) -> bool:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str) -> None:
def reset_email_register_error_rate_limit(email: str):
key = f"email_register_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str):
key = f"change_email_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -784,7 +891,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_owner_transfer_error_rate_limit(email: str) -> None:
def add_owner_transfer_error_rate_limit(email: str):
key = f"owner_transfer_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -858,7 +965,7 @@ class AccountService:
class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
@ -889,9 +996,7 @@ class TenantService:
return tenant
@staticmethod
def create_owner_tenant_if_not_exist(
account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
):
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
"""Check if user have a workspace or not"""
available_ta = (
db.session.query(TenantAccountJoin)
@ -925,7 +1030,7 @@ class TenantService:
"""Create tenant member"""
if role == TenantAccountRole.OWNER.value:
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
logging.error("Tenant %s has already an owner.", tenant.id)
logger.error("Tenant %s has already an owner.", tenant.id)
raise Exception("Tenant already has an owner.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
@ -963,7 +1068,7 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
def switch_tenant(account: Account, tenant_id: str | None = None):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
@ -1045,7 +1150,7 @@ class TenantService:
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]:
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
@ -1060,7 +1165,7 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
"""Check member permission"""
perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@ -1080,7 +1185,7 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
"""Remove member from tenant"""
if operator.id == account.id:
raise CannotOperateSelfError("Cannot operate self.")
@ -1095,7 +1200,7 @@ class TenantService:
db.session.commit()
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
@ -1122,10 +1227,10 @@ class TenantService:
db.session.commit()
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
def get_custom_config(tenant_id: str):
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)
return tenant.custom_config_dict
@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:
@ -1143,7 +1248,7 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
def setup(cls, email: str, name: str, password: str, ip_address: str):
"""
Setup dify
@ -1177,7 +1282,7 @@ class RegisterService:
db.session.query(Tenant).delete()
db.session.commit()
logging.exception("Setup account failed, email: %s, name: %s", email, name)
logger.exception("Setup account failed, email: %s, name: %s", email, name)
raise ValueError(f"Setup failed: {e}")
@classmethod
@ -1185,13 +1290,13 @@ class RegisterService:
cls,
email,
name,
password: Optional[str] = None,
open_id: Optional[str] = None,
provider: Optional[str] = None,
language: Optional[str] = None,
status: Optional[AccountStatus] = None,
is_setup: Optional[bool] = False,
create_workspace_required: Optional[bool] = True,
password: str | None = None,
open_id: str | None = None,
provider: str | None = None,
language: str | None = None,
status: AccountStatus | None = None,
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
) -> Account:
db.session.begin_nested()
"""Register account"""
@ -1222,15 +1327,15 @@ class RegisterService:
db.session.commit()
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
logging.exception("Register failed")
logger.exception("Register failed")
raise AccountRegisterError("Workspace is not allowed to create.")
except AccountRegisterError as are:
db.session.rollback()
logging.exception("Register failed")
logger.exception("Register failed")
raise are
except Exception as e:
db.session.rollback()
logging.exception("Register failed")
logger.exception("Register failed")
raise AccountRegisterError(f"Registration failed: {e}") from e
return account
@ -1308,10 +1413,8 @@ class RegisterService:
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod
def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
@ -1348,9 +1451,9 @@ class RegisterService:
}
@classmethod
def _get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
def get_invitation_by_token(
cls, token: str, workspace_id: str | None = None, email: str | None = None
) -> dict[str, str] | None:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"

View File

@ -17,7 +17,7 @@ from models.model import AppMode
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict) -> dict:
def get_prompt(cls, args: dict):
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
@ -29,17 +29,17 @@ class AdvancedPromptTemplateService:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
@ -52,7 +52,7 @@ class AdvancedPromptTemplateService:
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
@ -61,7 +61,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
@ -70,10 +70,10 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),

View File

@ -1,8 +1,7 @@
import threading
from typing import Optional
from typing import Any
import pytz
from flask_login import current_user
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
@ -10,13 +9,14 @@ from core.plugin.impl.agent import PluginAgentClient
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
class AgentService:
@classmethod
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
"""
Service to get agent logs
"""
@ -35,7 +35,7 @@ class AgentService:
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Optional[Message] = (
message: Message | None = (
db.session.query(Message)
.where(
Message.id == message_id,
@ -61,14 +61,15 @@ class AgentService:
executor = executor.name
else:
executor = "Unknown"
assert isinstance(current_user, Account)
assert current_user.timezone is not None
timezone = pytz.timezone(current_user.timezone)
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("App model config not found")
result = {
result: dict[str, Any] = {
"meta": {
"status": "success",
"executor": executor,

View File

@ -1,8 +1,6 @@
import uuid
from typing import cast
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -10,6 +8,8 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@ -24,6 +24,7 @@ class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -40,7 +41,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
annotation: MessageAnnotation | None = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
@ -62,6 +63,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_user.current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@ -70,10 +72,10 @@ class AppAnnotationService:
app_id,
annotation_setting.collection_binding_id,
)
return cast(MessageAnnotation, annotation)
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
def enable_app_annotation(cls, args: dict, app_id: str):
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@ -84,6 +86,8 @@ class AppAnnotationService:
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
enable_annotation_reply_task.delay(
str(job_id),
app_id,
@ -96,7 +100,9 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str) -> dict:
def disable_app_annotation(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
@ -113,6 +119,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -145,6 +153,8 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -164,6 +174,8 @@ class AppAnnotationService:
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -193,6 +205,8 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -230,6 +244,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -246,11 +262,9 @@ class AppAnnotationService:
db.session.delete(annotation)
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
annotation_hit_histories = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)
@ -269,6 +283,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -282,7 +298,7 @@ class AppAnnotationService:
annotations_to_delete = (
db.session.query(MessageAnnotation, AppAnnotationSetting)
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
.filter(MessageAnnotation.id.in_(annotation_ids))
.where(MessageAnnotation.id.in_(annotation_ids))
.all()
)
@ -315,8 +331,10 @@ class AppAnnotationService:
return {"deleted_count": deleted_count}
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -328,9 +346,9 @@ class AppAnnotationService:
try:
# Skip the first row
df = pd.read_csv(file, dtype=str)
df = pd.read_csv(file.stream, dtype=str)
result = []
for index, row in df.iterrows():
for _, row in df.iterrows():
content = {"question": row.iloc[0], "answer": row.iloc[1]}
result.append(content)
if len(result) == 0:
@ -355,6 +373,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -425,6 +445,8 @@ class AppAnnotationService:
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -438,19 +460,29 @@ class AppAnnotationService:
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -479,21 +511,31 @@ class AppAnnotationService:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
@classmethod
def clear_all_annotations(cls, app_id: str) -> dict:
def clear_all_annotations(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)

View File

@ -30,7 +30,7 @@ class APIBasedExtensionService:
return extension_data
@staticmethod
def delete(extension_data: APIBasedExtension) -> None:
def delete(extension_data: APIBasedExtension):
db.session.delete(extension_data)
db.session.commit()
@ -51,7 +51,7 @@ class APIBasedExtensionService:
return extension
@classmethod
def _validation(cls, extension_data: APIBasedExtension) -> None:
def _validation(cls, extension_data: APIBasedExtension):
# name
if not extension_data.name:
raise ValueError("name must not be empty")
@ -95,7 +95,7 @@ class APIBasedExtensionService:
cls._ping_connection(extension_data)
@staticmethod
def _ping_connection(extension_data: APIBasedExtension) -> None:
def _ping_connection(extension_data: APIBasedExtension):
try:
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
resp = client.request(point=APIBasedExtensionPoint.PING, params={})

View File

@ -4,7 +4,6 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4
@ -17,6 +16,7 @@ from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper import ssrf_proxy
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
@ -42,7 +42,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.3.1"
CURRENT_DSL_VERSION = "0.4.0"
class ImportMode(StrEnum):
@ -60,8 +60,8 @@ class ImportStatus(StrEnum):
class Import(BaseModel):
id: str
status: ImportStatus
app_id: Optional[str] = None
app_mode: Optional[str] = None
app_id: str | None = None
app_mode: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
@ -98,17 +98,17 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
class PendingData(BaseModel):
import_mode: str
yaml_content: str
name: str | None
description: str | None
icon_type: str | None
icon: str | None
icon_background: str | None
app_id: str | None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
app_id: str | None
app_id: str | None = None
class AppDslService:
@ -120,14 +120,14 @@ class AppDslService:
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
app_id: Optional[str] = None,
yaml_content: str | None = None,
yaml_url: str | None = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
app_id: str | None = None,
) -> Import:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
@ -406,15 +406,15 @@ class AppDslService:
def _create_or_update_app(
self,
*,
app: Optional[App],
app: App | None,
data: dict,
account: Account,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
dependencies: Optional[list[PluginDependency]] = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
dependencies: list[PluginDependency] | None = None,
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
@ -532,7 +532,7 @@ class AppDslService:
return app
@classmethod
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str:
"""
Export app
:param app_model: App instance
@ -556,7 +556,7 @@ class AppDslService:
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
cls._append_workflow_export_data(
export_data=export_data, app_model=app_model, include_secret=include_secret
export_data=export_data, app_model=app_model, include_secret=include_secret, workflow_id=workflow_id
)
else:
cls._append_model_config_export_data(export_data, app_model)
@ -564,14 +564,16 @@ class AppDslService:
return yaml.dump(export_data, allow_unicode=True) # type: ignore
@classmethod
def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
):
"""
Append workflow export data
:param export_data: export data
:param app_model: App instance
"""
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model)
workflow = workflow_service.get_draft_workflow(app_model, workflow_id)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
@ -606,7 +608,7 @@ class AppDslService:
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
"""
Append model config export data
:param export_data: export data
@ -784,7 +786,10 @@ class AppDslService:
@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
"""Encrypt dataset_id using AES-CBC mode or return plain text based on configuration"""
if not dify_config.DSL_EXPORT_ENCRYPT_DATASET_ID:
return dataset_id
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
@ -793,12 +798,34 @@ class AppDslService:
@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
"""AES decryption with fallback to plain text UUID"""
# First, check if it's already a plain UUID (not encrypted)
if cls._is_valid_uuid(encrypted_data):
return encrypted_data
# If it's not a UUID, try to decrypt it
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
decrypted_text = pt.decode()
# Validate that the decrypted result is a valid UUID
if cls._is_valid_uuid(decrypted_text):
return decrypted_text
else:
# If decrypted result is not a valid UUID, it's probably not our encrypted data
return None
except Exception:
# If decryption fails completely, return None
return None
@staticmethod
def _is_valid_uuid(value: str) -> bool:
"""Check if string is a valid UUID format"""
try:
uuid.UUID(value)
return True
except (ValueError, TypeError):
return False

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import Any, Union
from openai._exceptions import RateLimitError
@ -55,12 +55,12 @@ class AppGenerateService:
cls.system_rate_limiter.increment_rate_limit(app_model.tenant_id)
# app level rate limiter
max_active_request = AppGenerateService._get_max_active_requests(app_model)
max_active_request = cls._get_max_active_requests(app_model)
rate_limit = RateLimit(app_model.id, max_active_request)
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
@ -69,7 +69,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
@ -78,7 +78,7 @@ class AppGenerateService:
),
request_id,
)
elif app_model.mode == AppMode.CHAT.value:
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
@ -87,7 +87,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -103,7 +103,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -155,14 +155,14 @@ class AppGenerateService:
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
@ -174,14 +174,14 @@ class AppGenerateService:
@classmethod
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
@ -214,7 +214,7 @@ class AppGenerateService:
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow:
"""
Get workflow
:param app_model: app model
@ -227,7 +227,7 @@ class AppGenerateService:
# If workflow_id is specified, get the specific workflow version
if workflow_id:
try:
workflow_uuid = uuid.UUID(workflow_id)
_ = uuid.UUID(workflow_id)
except ValueError:
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ")
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)

View File

@ -6,7 +6,7 @@ from models.model import AppMode
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:

View File

@ -1,8 +1,7 @@
import json
import logging
from typing import Optional, TypedDict, cast
from typing import TypedDict, cast
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
@ -25,6 +25,8 @@ from services.feature_service import FeatureService
from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
logger = logging.getLogger(__name__)
class AppService:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
@ -38,15 +40,15 @@ class AppService:
filters = [App.tenant_id == tenant_id, App.is_universal == False]
if args["mode"] == "workflow":
filters.append(App.mode == AppMode.WORKFLOW.value)
filters.append(App.mode == AppMode.WORKFLOW)
elif args["mode"] == "completion":
filters.append(App.mode == AppMode.COMPLETION.value)
filters.append(App.mode == AppMode.COMPLETION)
elif args["mode"] == "chat":
filters.append(App.mode == AppMode.CHAT.value)
filters.append(App.mode == AppMode.CHAT)
elif args["mode"] == "advanced-chat":
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
filters.append(App.mode == AppMode.ADVANCED_CHAT)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
filters.append(App.mode == AppMode.AGENT_CHAT)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
@ -94,8 +96,8 @@ class AppService:
)
except (ProviderTokenNotInitError, LLMBadRequestError):
model_instance = None
except Exception as e:
logging.exception("Get default model instance failed, tenant_id: %s", tenant_id)
except Exception:
logger.exception("Get default model instance failed, tenant_id: %s", tenant_id)
model_instance = None
if model_instance:
@ -166,9 +168,13 @@ class AppService:
"""
Get App
"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
model_config = app.app_model_config
if not model_config:
return app
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get("tools") or []:
@ -199,11 +205,12 @@ class AppService:
# override tool parameters
tool["tool_parameters"] = masked_parameter
except Exception as e:
except Exception:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
if model_config:
model_config.agent_mode = json.dumps(agent_mode)
class ModifiedApp(App):
"""
@ -237,6 +244,7 @@ class AppService:
:param args: request args
:return: App instance
"""
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = args["icon_type"]
@ -257,6 +265,7 @@ class AppService:
:param name: new name
:return: App instance
"""
assert current_user is not None
app.name = name
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@ -272,6 +281,7 @@ class AppService:
:param icon_background: new icon_background
:return: App instance
"""
assert current_user is not None
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
@ -289,7 +299,7 @@ class AppService:
"""
if enable_site == app.enable_site:
return app
assert current_user is not None
app.enable_site = enable_site
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@ -306,6 +316,7 @@ class AppService:
"""
if enable_api == app.enable_api:
return app
assert current_user is not None
app.enable_api = enable_api
app.updated_by = current_user.id
@ -314,7 +325,7 @@ class AppService:
return app
def delete_app(self, app: App) -> None:
def delete_app(self, app: App):
"""
Delete app
:param app: App instance
@ -329,7 +340,7 @@ class AppService:
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
def get_app_meta(self, app_model: App) -> dict:
def get_app_meta(self, app_model: App):
"""
Get app meta info
:param app_model: app model
@ -359,7 +370,7 @@ class AppService:
}
)
else:
app_model_config: Optional[AppModelConfig] = app_model.app_model_config
app_model_config: AppModelConfig | None = app_model.app_model_config
if not app_model_config:
return meta
@ -382,7 +393,7 @@ class AppService:
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: Optional[ApiToolProvider] = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
if provider is None:

View File

@ -2,7 +2,6 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import Optional
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@ -12,7 +11,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message
from models.model import App, AppMode, Message
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
@ -30,8 +29,8 @@ logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise ValueError("Speech to text is not enabled")
@ -40,7 +39,9 @@ class AudioService:
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
raise ValueError("Speech to text is not enabled")
else:
app_model_config: AppModelConfig = app_model.app_model_config
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Speech to text is not enabled")
if not app_model_config.speech_to_text_dict["enabled"]:
raise ValueError("Speech to text is not enabled")
@ -75,18 +76,18 @@ class AudioService:
def transcript_tts(
cls,
app_model: App,
text: Optional[str] = None,
voice: Optional[str] = None,
end_user: Optional[str] = None,
message_id: Optional[str] = None,
text: str | None = None,
voice: str | None = None,
end_user: str | None = None,
message_id: str | None = None,
is_draft: bool = False,
):
from app import app
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:

View File

@ -1,5 +1,7 @@
import json
from sqlalchemy import select
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
@ -8,12 +10,12 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings
@staticmethod

View File

@ -1,5 +1,5 @@
import os
from typing import Literal, Optional
from typing import Literal
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@ -70,10 +70,10 @@ class BillingService:
return response.json()
@staticmethod
def is_tenant_owner_or_admin(current_user):
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: Optional[TenantAccountJoin] = (
join: TenantAccountJoin | None = (
db.session.query(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()

View File

@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -34,7 +35,7 @@ logger = logging.getLogger(__name__)
class ClearFreePlanTenantExpiredLogs:
@classmethod
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
"""
Clean up message-related tables to avoid data redundancy.
This method cleans up tables that have foreign key relationships with Message.
@ -62,7 +63,7 @@ class ClearFreePlanTenantExpiredLogs:
# Query records related to expired messages
records = (
session.query(model)
.filter(
.where(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
@ -101,7 +102,7 @@ class ClearFreePlanTenantExpiredLogs:
except Exception:
logger.exception("Failed to save %s records", table_name)
session.query(model).filter(
session.query(model).where(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:
@ -295,7 +296,7 @@ class ClearFreePlanTenantExpiredLogs:
with Session(db.engine).no_autoflush as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
.filter(
.where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
@ -321,9 +322,9 @@ class ClearFreePlanTenantExpiredLogs:
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
# delete workflow app logs
session.query(WorkflowAppLog).filter(
WorkflowAppLog.id.in_(workflow_app_log_ids),
).delete(synchronize_session=False)
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
)
session.commit()
click.echo(
@ -353,7 +354,7 @@ class ClearFreePlanTenantExpiredLogs:
thread_pool = ThreadPoolExecutor(max_workers=10)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
try:
if (
not dify_config.BILLING_ENABLED
@ -407,6 +408,7 @@ class ClearFreePlanTenantExpiredLogs:
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)

View File

@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
@staticmethod
def get_code_based_extension(module: str) -> list[dict]:
def get_code_based_extension(module: str):
module_extensions = code_based_extension.module_extensions(module)
return [
{

View File

@ -1,7 +1,7 @@
import contextlib
import logging
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
@ -36,12 +36,12 @@ class ConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
include_ids: Sequence[str] | None = None,
exclude_ids: Sequence[str] | None = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
@ -118,7 +118,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
name: str,
auto_generate: bool,
):
@ -158,7 +158,7 @@ class ConversationService:
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
conversation = (
db.session.query(Conversation)
.where(
@ -178,7 +178,7 @@ class ConversationService:
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
try:
logger.info(
"Initiating conversation deletion for app_name %s, conversation_id: %s",
@ -200,9 +200,9 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
limit: int,
last_id: Optional[str],
last_id: str | None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
@ -222,8 +222,8 @@ class ConversationService:
# Filter for variables created after the last_id
stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
# Apply limit to query
query_stmt = stmt.limit(limit) # Get one extra to check if there are more
# Apply limit to query: fetch one extra row to determine has_more
query_stmt = stmt.limit(limit + 1)
rows = session.scalars(query_stmt).all()
has_more = False
@ -248,9 +248,9 @@ class ConversationService:
app_model: App,
conversation_id: str,
variable_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
new_value: Any,
) -> dict:
):
"""
Update a conversation variable's value.

View File

@ -6,10 +6,11 @@ import secrets
import time
import uuid
from collections import Counter
from typing import Any, Literal, Optional
from collections.abc import Sequence
from typing import Any, Literal
from flask_login import current_user
from sqlalchemy import func, select
import sqlalchemy as sa
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -27,6 +28,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -76,6 +78,8 @@ from tasks.remove_document_from_index_task import remove_document_from_index_tas
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
logger = logging.getLogger(__name__)
class DatasetService:
@staticmethod
@ -131,7 +135,14 @@ class DatasetService:
# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
if tenant_id is not None:
target_ids = TagService.get_target_ids_by_tag_ids(
"knowledge",
tenant_id,
tag_ids,
)
else:
target_ids = []
if target_ids and len(target_ids) > 0:
query = query.where(Dataset.id.in_(target_ids))
else:
@ -174,16 +185,16 @@ class DatasetService:
def create_empty_dataset(
tenant_id: str,
name: str,
description: Optional[str],
indexing_technique: Optional[str],
description: str | None,
indexing_technique: str | None,
account: Account,
permission: Optional[str] = None,
permission: str | None = None,
provider: str = "vendor",
external_knowledge_api_id: Optional[str] = None,
external_knowledge_id: Optional[str] = None,
embedding_model_provider: Optional[str] = None,
embedding_model_name: Optional[str] = None,
retrieval_model: Optional[RetrievalModel] = None,
external_knowledge_api_id: str | None = None,
external_knowledge_id: str | None = None,
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
@ -210,7 +221,7 @@ class DatasetService:
and retrieval_model.reranking_model.reranking_model_name
):
# check if reranking model setting is valid
DatasetService.check_embedding_model_setting(
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
@ -246,8 +257,8 @@ class DatasetService:
return dataset
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
@ -492,8 +503,11 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
try:
model_manager = ModelManager()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@ -605,8 +619,12 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
model_manager = ModelManager()
try:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@ -615,7 +633,7 @@ class DatasetService:
)
except ProviderTokenNotInitError:
# If we can't get the embedding model, preserve existing settings
logging.warning(
logger.warning(
"Failed to initialize embedding model %s/%s, preserving existing settings",
data["embedding_model_provider"],
data["embedding_model"],
@ -653,19 +671,17 @@ class DatasetService:
@staticmethod
def dataset_use_check(dataset_id) -> bool:
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0:
return True
return False
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
return db.session.execute(stmt).scalar_one()
@staticmethod
def check_dataset_permission(dataset, user):
if dataset.tenant_id != user.current_tenant_id:
logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
# For partial team permission, user needs explicit permission or be the creator
@ -674,11 +690,11 @@ class DatasetService:
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
)
if not user_permission:
logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
if not dataset:
raise ValueError("Dataset not found")
@ -715,7 +731,9 @@ class DatasetService:
)
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
@ -724,14 +742,12 @@ class DatasetService:
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.where(
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
).all()
if dataset_auto_disable_logs:
return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
@ -852,7 +868,7 @@ class DocumentService:
}
@staticmethod
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
if document_id:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
@ -862,73 +878,64 @@ class DocumentService:
return None
@staticmethod
def get_document_by_id(document_id: str) -> Optional[Document]:
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
return document
@staticmethod
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()
return documents
@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
.all()
)
).all()
return documents
@staticmethod
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()
return documents
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all()
)
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
).all()
return documents
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
assert isinstance(current_user, Account)
documents = db.session.scalars(
select(Document).where(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id,
)
.all()
)
).all()
return documents
@ -965,13 +972,14 @@ class DocumentService:
# Check if document_ids is not empty to avoid WHERE false condition
if not document_ids or len(document_ids) == 0:
return
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
if document.data_source_type == "upload_file"
if document.data_source_type == "upload_file" and document.data_source_info_dict
]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
for document in documents:
db.session.delete(document)
@ -979,6 +987,8 @@ class DocumentService:
@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found.")
@ -994,7 +1004,7 @@ class DocumentService:
if dataset.built_in_field_enabled:
if document.doc_metadata:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = name
doc_metadata[BuiltInField.document_name] = name
document.doc_metadata = doc_metadata
document.name = name
@ -1008,6 +1018,7 @@ class DocumentService:
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
raise DocumentIndexingError()
# update document to be paused
assert current_user is not None
document.is_paused = True
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
@ -1063,8 +1074,9 @@ class DocumentService:
# sync document indexing
document.indexing_status = "waiting"
data_source_info = document.data_source_info_dict
data_source_info["mode"] = "scrape"
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
if data_source_info:
data_source_info["mode"] = "scrape"
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
db.session.add(document)
db.session.commit()
@ -1087,12 +1099,15 @@ class DocumentService:
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: Optional[DatasetProcessRule] = None,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
):
) -> tuple[list[Document], str]:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
# check document limit
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@ -1149,7 +1164,7 @@ class DocumentService:
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -1190,11 +1205,11 @@ class DocumentService:
created_by=account.id,
)
else:
logging.warning(
logger.warning(
"Invalid process rule mode: %s, can not find dataset process rule",
process_rule.mode,
)
return
return [], ""
db.session.add(dataset_process_rule)
db.session.commit()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
@ -1429,6 +1444,8 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
assert isinstance(current_user, Account)
documents_count = (
db.session.query(Document)
.where(
@ -1446,9 +1463,11 @@ class DocumentService:
dataset: Dataset,
document_data: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule] = None,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
):
assert isinstance(current_user, Account)
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data.original_document_id)
if document is None:
@ -1508,7 +1527,7 @@ class DocumentService:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
@ -1569,6 +1588,9 @@ class DocumentService:
@staticmethod
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@ -1612,7 +1634,7 @@ class DocumentService:
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
reranking_enable=False,
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
top_k=2,
top_k=4,
score_threshold_enabled=False,
)
# save dataset
@ -1882,7 +1904,7 @@ class DocumentService:
task_func.delay(*task_args)
except Exception as e:
# Log the error but do not rollback the transaction
logging.exception("Error executing async task for document %s", update_info["document"].id)
logger.exception("Error executing async task for document %s", update_info["document"].id)
# don't raise the error immediately, but capture it for later
propagation_error = e
try:
@ -1893,7 +1915,7 @@ class DocumentService:
redis_client.setex(indexing_cache_key, 600, 1)
except Exception as e:
# Log the error but do not rollback the transaction
logging.exception("Error setting cache for document %s", update_info["document"].id)
logger.exception("Error setting cache for document %s", update_info["document"].id)
# Raise any propagation error after all updates
if propagation_error:
raise propagation_error
@ -2008,6 +2030,9 @@ class SegmentService:
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
content = args["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
@ -2059,7 +2084,7 @@ class SegmentService:
try:
VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form)
except Exception as e:
logging.exception("create segment index failed")
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
@ -2070,6 +2095,9 @@ class SegmentService:
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0
with redis_client.lock(lock_name, timeout=600):
@ -2142,7 +2170,7 @@ class SegmentService:
# save vector index
VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form)
except Exception as e:
logging.exception("create segment index failed")
logger.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
@ -2153,6 +2181,9 @@ class SegmentService:
@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
@ -2314,7 +2345,7 @@ class SegmentService:
VectorService.update_segment_vector(args.keywords, segment, dataset)
except Exception as e:
logging.exception("update segment index failed")
logger.exception("update segment index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
@ -2334,7 +2365,22 @@ class SegmentService:
if segment.enabled:
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
# Get child chunk IDs before parent segment is deleted
child_node_ids = []
if segment.index_node_id:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
db.session.delete(segment)
# update document word count
assert document.word_count is not None
@ -2344,9 +2390,14 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
.filter(
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
segments_info = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@ -2355,16 +2406,36 @@ class SegmentService:
.all()
)
if not segments:
if not segments_info:
return
index_node_ids = [seg.index_node_id for seg in segments]
total_words = sum(seg.word_count for seg in segments)
index_node_ids = [info[0] for info in segments_info]
segment_db_ids = [info[1] for info in segments_info]
total_words = sum(info[2] for info in segments_info if info[2] is not None)
document.word_count -= total_words
# Get child chunk IDs before parent segments are deleted
child_node_ids = []
if index_node_ids:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
document.word_count = (
document.word_count - total_words if document.word_count and document.word_count > total_words else 0
)
db.session.add(document)
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
# Delete database records
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.commit()
@ -2372,20 +2443,20 @@ class SegmentService:
def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
):
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
if action == "enable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@ -2403,16 +2474,14 @@ class SegmentService:
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@ -2434,20 +2503,12 @@ class SegmentService:
def create_child_chunk(
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
) -> ChildChunk:
assert isinstance(current_user, Account)
lock_name = f"add_child_lock_{segment.id}"
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(content)
child_chunk_count = (
db.session.query(ChildChunk)
.where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.count()
)
max_position = (
db.session.query(func.max(ChildChunk.position))
.where(
@ -2476,7 +2537,7 @@ class SegmentService:
try:
VectorService.create_child_chunk_vector(child_chunk, dataset)
except Exception as e:
logging.exception("create child chunk index failed")
logger.exception("create child chunk index failed")
db.session.rollback()
raise ChildChunkIndexingError(str(e))
db.session.commit()
@ -2491,15 +2552,14 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> list[ChildChunk]:
child_chunks = (
db.session.query(ChildChunk)
.where(
assert isinstance(current_user, Account)
child_chunks = db.session.scalars(
select(ChildChunk).where(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.all()
)
).all()
child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
@ -2551,7 +2611,7 @@ class SegmentService:
VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset)
db.session.commit()
except Exception as e:
logging.exception("update child chunk index failed")
logger.exception("update child chunk index failed")
db.session.rollback()
raise ChildChunkIndexingError(str(e))
return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position)
@ -2565,6 +2625,8 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> ChildChunk:
assert current_user is not None
try:
child_chunk.content = content
child_chunk.word_count = len(content)
@ -2575,7 +2637,7 @@ class SegmentService:
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
db.session.commit()
except Exception as e:
logging.exception("update child chunk index failed")
logger.exception("update child chunk index failed")
db.session.rollback()
raise ChildChunkIndexingError(str(e))
return child_chunk
@ -2586,15 +2648,17 @@ class SegmentService:
try:
VectorService.delete_child_chunk_vector(child_chunk, dataset)
except Exception as e:
logging.exception("delete child chunk index failed")
logger.exception("delete child chunk index failed")
db.session.rollback()
raise ChildChunkDeleteIndexError(str(e))
db.session.commit()
@classmethod
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: str | None = None
):
assert isinstance(current_user, Account)
query = (
select(ChildChunk)
.filter_by(
@ -2610,7 +2674,7 @@ class SegmentService:
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
@ -2647,57 +2711,7 @@ class SegmentService:
return paginated_segments.items, paginated_segments.total
@classmethod
def update_segment_by_id(
cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
) -> tuple[DocumentSegment, Document]:
"""Update a segment by its ID with validation and checks."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check embedding model setting if high quality
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=user_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
# check segment
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# validate and update segment
cls.segment_create_args_validate(segment_data, document)
updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)
return updated_segment, document
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
@ -2755,19 +2769,13 @@ class DatasetCollectionBindingService:
class DatasetPermissionService:
@classmethod
def get_dataset_partial_member_list(cls, dataset_id):
user_list_query = (
db.session.query(
user_list_query = db.session.scalars(
select(
DatasetPermission.account_id,
)
.where(DatasetPermission.dataset_id == dataset_id)
.all()
)
).where(DatasetPermission.dataset_id == dataset_id)
).all()
user_list = []
for user in user_list_query:
user_list.append(user.account_id)
return user_list
return user_list_query
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):

View File

@ -3,18 +3,30 @@ import os
import requests
class EnterpriseRequest:
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
class BaseRequest:
proxies = {
"http": "",
"https": "",
}
base_url = ""
secret_key = ""
secret_key_header = ""
@classmethod
def send_request(cls, method, endpoint, json=None, params=None):
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies)
return response.json()
class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
secret_key_header = "Enterprise-Api-Secret-Key"
class EnterprisePluginManagerRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL")
secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY")
secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key"

View File

@ -0,0 +1,57 @@
import enum
import logging
from pydantic import BaseModel
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError
logger = logging.getLogger(__name__)
class PluginCredentialType(enum.IntEnum):
    """Credential categories understood by the plugin manager API."""

    MODEL = 1
    TOOL = 2

    def to_number(self):
        """Return the wire representation: the member's plain integer value."""
        return int(self)
class CheckCredentialPolicyComplianceRequest(BaseModel):
    """Payload for the plugin-manager credential-policy compliance check."""

    dify_credential_id: str
    provider: str
    credential_type: PluginCredentialType

    def model_dump(self, **kwargs):
        """Serialize the model, forcing the enum into its numeric wire form."""
        payload = super().model_dump(**kwargs)
        payload["credential_type"] = self.credential_type.to_number()
        return payload
class CredentialPolicyViolationError(BaseServiceError):
    """Raised when a credential fails the enterprise policy compliance check."""
class PluginManagerService:
    """Service-layer entry points for the enterprise plugin manager."""

    @classmethod
    def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest):
        """Ask the plugin manager whether the given credential is policy-compliant.

        :param body: the compliance-check request payload
        :raises CredentialPolicyViolationError: when the API call fails, the
            response is malformed, or the API reports non-compliance
        """
        try:
            response = EnterprisePluginManagerRequest.send_request(
                "POST", "/check-credential-policy-compliance", json=body.model_dump()
            )
            # Guard against unexpected payloads before trusting "result".
            if not isinstance(response, dict) or "result" not in response:
                raise ValueError("Invalid response format from plugin manager API")
        except Exception as e:
            # Any transport or format failure is surfaced as a policy violation.
            raise CredentialPolicyViolationError(
                f"error occurred while checking credential policy compliance: {e}"
            ) from e

        if not response.get("result", False):
            raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")

        logger.debug(
            "Credential policy compliance checked for %s with credential %s, result: %s",
            body.provider,
            body.dify_credential_id,
            response.get("result", False),
        )

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union
from pydantic import BaseModel
@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel):
class Authorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[AuthorizationConfig] = None
config: AuthorizationConfig | None = None
class ProcessStatusSetting(BaseModel):
@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel):
class ExternalKnowledgeApiSetting(BaseModel):
url: str
request_method: str
headers: Optional[dict] = None
params: Optional[dict] = None
headers: dict | None = None
params: dict | None = None

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel
@ -11,14 +11,14 @@ class ParentMode(StrEnum):
class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None
url: str | None = None
emoji: str | None = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
page_icon: NotionIcon | None = None
type: str
@ -40,9 +40,9 @@ class FileInfo(BaseModel):
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None
notion_info_list: list[NotionInfo] | None = None
file_info_list: FileInfo | None = None
website_info_list: WebsiteInfo | None = None
class DataSource(BaseModel):
@ -61,20 +61,20 @@ class Segmentation(BaseModel):
class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None
rules: Rule | None = None
class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None
reranking_provider_name: str | None = None
reranking_model_name: str | None = None
class WeightVectorSetting(BaseModel):
@ -88,20 +88,20 @@ class WeightKeywordSetting(BaseModel):
class WeightModel(BaseModel):
weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None
vector_setting: Optional[WeightVectorSetting] = None
keyword_setting: Optional[WeightKeywordSetting] = None
weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None
vector_setting: WeightVectorSetting | None = None
keyword_setting: WeightKeywordSetting | None = None
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
reranking_mode: Optional[str] = None
reranking_model: RerankingModel | None = None
reranking_mode: str | None = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None
weights: Optional[WeightModel] = None
score_threshold: float | None = None
weights: WeightModel | None = None
class MetaDataConfig(BaseModel):
@ -110,29 +110,29 @@ class MetaDataConfig(BaseModel):
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
original_document_id: str | None = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None
enabled: bool | None = None
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
id: str | None = None
content: str
@ -143,13 +143,13 @@ class MetadataArgs(BaseModel):
class MetadataUpdateArgs(BaseModel):
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class MetadataDetail(BaseModel):
id: str
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class DocumentMetadataOperation(BaseModel):

View File

@ -1,5 +1,4 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict
@ -8,7 +7,13 @@ from core.entities.model_entities import (
ModelWithProviderEntity,
ProviderModelWithStatusEntity,
)
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
from core.entities.provider_entities import (
CredentialConfiguration,
CustomModelConfiguration,
ProviderQuotaType,
QuotaConfiguration,
UnaddedModelConfiguration,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
@ -36,6 +41,11 @@ class CustomConfigurationResponse(BaseModel):
"""
status: CustomConfigurationStatus
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] | None = None
custom_models: list[CustomModelConfiguration] | None = None
can_added_models: list[UnaddedModelConfiguration] | None = None
class SystemConfigurationResponse(BaseModel):
@ -44,7 +54,7 @@ class SystemConfigurationResponse(BaseModel):
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
@ -56,15 +66,15 @@ class ProviderResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
@ -72,7 +82,7 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -97,12 +107,12 @@ class ProviderWithModelsResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -126,7 +136,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
tenant_id: str
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -163,7 +173,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
dump_model = model.model_dump()
dump_model["provider"]["tenant_id"] = tenant_id
super().__init__(**dump_model)

View File

@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError
class AppModelConfigBrokenError(BaseServiceError):
pass
class ProviderNotFoundError(BaseServiceError):
pass

View File

@ -1,6 +1,3 @@
from typing import Optional
class BaseServiceError(ValueError):
def __init__(self, description: Optional[str] = None):
def __init__(self, description: str | None = None):
self.description = description

View File

@ -1,12 +1,9 @@
from typing import Optional
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
description: str | None = None
def __init__(self, description: Optional[str] = None) -> None:
def __init__(self, description: str | None = None):
self.description = description
def __str__(self):

View File

@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from urllib.parse import urlparse
import httpx
@ -9,6 +9,7 @@ from sqlalchemy import select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from core.workflow.nodes.http_request.exc import InvalidHttpMethodError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import (
@ -88,7 +89,7 @@ class ExternalDatasetService:
raise ValueError(f"invalid endpoint: {endpoint}")
try:
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
except Exception as e:
except Exception:
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
if response.status_code == 502:
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
@ -99,7 +100,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
@ -108,13 +109,14 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key")
settings = args.get("settings")
if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")
external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "")
@ -149,7 +151,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
@ -179,19 +181,29 @@ class ExternalDatasetService:
do http request depending on api bundle
"""
kwargs = {
kwargs: dict[str, Any] = {
"url": settings.url,
"headers": settings.headers,
"follow_redirects": True,
}
response: httpx.Response = getattr(ssrf_proxy, settings.request_method)(
data=json.dumps(settings.params), files=files, **kwargs
)
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = settings.request_method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {settings.request_method}")
response: httpx.Response = _METHOD_MAP[method_lc](data=json.dumps(settings.params), files=files, **kwargs)
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
@ -218,7 +230,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.parse_obj(settings)
return ExternalKnowledgeApiSetting.model_validate(settings)
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
@ -265,8 +277,8 @@ class ExternalDatasetService:
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)

View File

@ -134,6 +134,10 @@ class KnowledgeRateLimitModel(BaseModel):
subscription_plan: str = ""
class PluginManagerModel(BaseModel):
enabled: bool = False
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
@ -150,6 +154,7 @@ class SystemFeatureModel(BaseModel):
webapp_auth: WebAppAuthModel = WebAppAuthModel()
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
class FeatureService:
@ -188,6 +193,7 @@ class FeatureService:
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:

View File

@ -1,9 +1,8 @@
import hashlib
import os
import uuid
from typing import Any, Literal, Union
from typing import Literal, Union
from flask_login import current_user
from werkzeug.exceptions import NotFound
from configs import dify_config
@ -19,6 +18,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from libs.login import current_user
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@ -35,7 +35,7 @@ class FileService:
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser, Any],
user: Union[Account, EndUser],
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:
@ -111,6 +111,9 @@ class FileService:
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if len(text_name) > 200:
text_name = text_name[:200]
# user uuid as file name

View File

@ -12,11 +12,13 @@ from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DatasetQuery
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -31,7 +33,7 @@ class HitTestingService:
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
):
start = time.perf_counter()
# get retrieval model , if the model is not setting , using default
@ -64,7 +66,7 @@ class HitTestingService:
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k", 2),
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
@ -77,7 +79,7 @@ class HitTestingService:
)
end = time.perf_counter()
logging.debug("Hit testing retrieve in %s seconds", end - start)
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
@ -96,7 +98,7 @@ class HitTestingService:
account: Account,
external_retrieval_model: dict,
metadata_filtering_conditions: dict,
) -> dict:
):
if dataset.provider != "external":
return {
"query": {"content": query},
@ -113,7 +115,7 @@ class HitTestingService:
)
end = time.perf_counter()
logging.debug("External knowledge hit testing retrieve in %s seconds", end - start)
logger.debug("External knowledge hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id

View File

@ -1,5 +1,5 @@
import json
from typing import Optional, Union
from typing import Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -29,9 +29,9 @@ class MessageService:
def pagination_by_first_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
conversation_id: str,
first_id: Optional[str],
first_id: str | None,
limit: int,
order: str = "asc",
) -> InfiniteScrollPagination:
@ -91,11 +91,11 @@ class MessageService:
def pagination_by_last_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
conversation_id: Optional[str] = None,
include_ids: Optional[list] = None,
conversation_id: str | None = None,
include_ids: list | None = None,
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@ -112,7 +112,9 @@ class MessageService:
base_query = base_query.where(Message.conversation_id == conversation.id)
# Check if include_ids is not None and not empty to avoid WHERE false condition
if include_ids is not None and len(include_ids) > 0:
if include_ids is not None:
if len(include_ids) == 0:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = base_query.where(Message.id.in_(include_ids))
if last_id:
@ -143,9 +145,9 @@ class MessageService:
*,
app_model: App,
message_id: str,
user: Optional[Union[Account, EndUser]],
rating: Optional[str],
content: Optional[str],
user: Union[Account, EndUser] | None,
rating: str | None,
content: str | None,
):
if not user:
raise ValueError("user cannot be None")
@ -194,7 +196,7 @@ class MessageService:
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
message = (
db.session.query(Message)
.where(
@ -214,7 +216,7 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
) -> list[Message]:
if not user:
raise ValueError("user cannot be None")
@ -227,7 +229,7 @@ class MessageService:
model_manager = ModelManager()
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)

View File

@ -1,6 +1,5 @@
import copy
import logging
from typing import Optional
from flask_login import current_user
@ -15,6 +14,8 @@ from services.entities.knowledge_entities.knowledge_entities import (
MetadataOperationData,
)
logger = logging.getLogger(__name__)
class MetadataService:
@staticmethod
@ -90,7 +91,7 @@ class MetadataService:
db.session.commit()
return metadata # type: ignore
except Exception:
logging.exception("Update metadata name failed")
logger.exception("Update metadata name failed")
finally:
redis_client.delete(lock_key)
@ -122,18 +123,18 @@ class MetadataService:
db.session.commit()
return metadata
except Exception:
logging.exception("Delete metadata failed")
logger.exception("Delete metadata failed")
finally:
redis_client.delete(lock_key)
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name.value, "type": "string"},
{"name": BuiltInField.uploader.value, "type": "string"},
{"name": BuiltInField.upload_date.value, "type": "time"},
{"name": BuiltInField.last_update_date.value, "type": "time"},
{"name": BuiltInField.source.value, "type": "string"},
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
]
@staticmethod
@ -151,17 +152,17 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
dataset.built_in_field_enabled = True
db.session.commit()
except Exception:
logging.exception("Enable built-in field failed")
logger.exception("Enable built-in field failed")
finally:
redis_client.delete(lock_key)
@ -181,18 +182,18 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name.value, None)
doc_metadata.pop(BuiltInField.uploader.value, None)
doc_metadata.pop(BuiltInField.upload_date.value, None)
doc_metadata.pop(BuiltInField.last_update_date.value, None)
doc_metadata.pop(BuiltInField.source.value, None)
doc_metadata.pop(BuiltInField.document_name, None)
doc_metadata.pop(BuiltInField.uploader, None)
doc_metadata.pop(BuiltInField.upload_date, None)
doc_metadata.pop(BuiltInField.last_update_date, None)
doc_metadata.pop(BuiltInField.source, None)
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
dataset.built_in_field_enabled = False
db.session.commit()
except Exception:
logging.exception("Disable built-in field failed")
logger.exception("Disable built-in field failed")
finally:
redis_client.delete(lock_key)
@ -209,11 +210,11 @@ class MetadataService:
for metadata_value in operation.metadata_list:
doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
@ -230,12 +231,12 @@ class MetadataService:
db.session.add(dataset_metadata_binding)
db.session.commit()
except Exception:
logging.exception("Update documents metadata failed")
logger.exception("Update documents metadata failed")
finally:
redis_client.delete(lock_key)
@staticmethod
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None):
if dataset_id:
lock_key = f"dataset_metadata_lock_{dataset_id}"
if redis_client.get(lock_key):

View File

@ -1,7 +1,9 @@
import json
import logging
from json import JSONDecodeError
from typing import Optional, Union
from typing import Union
from sqlalchemy import or_, select
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@ -17,16 +19,16 @@ from core.model_runtime.model_providers.model_provider_factory import ModelProvi
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import LoadBalancingModelConfig
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
def __init__(self) -> None:
def __init__(self):
self.provider_manager = ProviderManager()
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model load balancing.
@ -47,7 +49,7 @@ class ModelLoadBalancingService:
# Enable model load balancing
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model load balancing.
@ -69,7 +71,7 @@ class ModelLoadBalancingService:
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
"""
Get load balancing configurations.
@ -100,6 +102,11 @@ class ModelLoadBalancingService:
if provider_model_setting and provider_model_setting.load_balancing_enabled:
is_load_balancing_enabled = True
if config_from == "predefined-model":
credential_source_type = "provider"
else:
credential_source_type = "custom_model"
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
@ -108,6 +115,10 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
@ -154,7 +165,7 @@ class ModelLoadBalancingService:
try:
if load_balancing_config.encrypted_config:
credentials = json.loads(load_balancing_config.encrypted_config)
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
@ -169,9 +180,13 @@ class ModelLoadBalancingService:
for variable in credential_secret_variables:
if variable in credentials:
try:
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
)
token_value = credentials.get(variable)
if isinstance(token_value, str):
credentials[variable] = encrypter.decrypt_token_with_decoding(
token_value,
decoding_rsa_key,
decoding_cipher_rsa,
)
except ValueError:
pass
@ -185,6 +200,7 @@ class ModelLoadBalancingService:
"id": load_balancing_config.id,
"name": load_balancing_config.name,
"credentials": credentials,
"credential_id": load_balancing_config.credential_id,
"enabled": load_balancing_config.enabled,
"in_cooldown": in_cooldown,
"ttl": ttl,
@ -195,7 +211,7 @@ class ModelLoadBalancingService:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> Optional[dict]:
) -> dict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@ -280,8 +296,8 @@ class ModelLoadBalancingService:
return inherit_config
def update_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
) -> None:
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
):
"""
Update load balancing configurations.
:param tenant_id: workspace id
@ -289,6 +305,7 @@ class ModelLoadBalancingService:
:param model: model name
:param model_type: model type
:param configs: load balancing configs
:param config_from: predefined-model or custom-model
:return:
"""
# Get all provider configurations of the current workspace
@ -305,16 +322,14 @@ class ModelLoadBalancingService:
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
current_load_balancing_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
)
).all()
# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
@ -327,8 +342,38 @@ class ModelLoadBalancingService:
config_id = config.get("id")
name = config.get("name")
credentials = config.get("credentials")
credential_id = config.get("credential_id")
enabled = config.get("enabled")
credential_record: ProviderCredential | ProviderModelCredential | None = None
if credential_id:
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
)
.first()
)
else:
credential_record = (
db.session.query(ProviderModelCredential)
.filter_by(
id=credential_id,
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_name=model,
model_type=model_type_enum.to_origin_model_type(),
)
.first()
)
if not credential_record:
raise ValueError(f"Provider credential with id {credential_id} not found")
name = credential_record.credential_name
if not name:
raise ValueError("Invalid load balancing config name")
@ -346,11 +391,6 @@ class ModelLoadBalancingService:
load_balancing_config = current_load_balancing_configs_dict[config_id]
# check duplicate name
for current_load_balancing_config in current_load_balancing_configs:
if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
raise ValueError(f"Load balancing config name {name} already exists")
if credentials:
if not isinstance(credentials, dict):
raise ValueError("Invalid load balancing config credentials")
@ -380,36 +420,45 @@ class ModelLoadBalancingService:
if name == "__inherit__":
raise ValueError("Invalid load balancing config name")
# check duplicate name
for current_load_balancing_config in current_load_balancing_configs:
if current_load_balancing_config.name == name:
raise ValueError(f"Load balancing config name {name} already exists")
if credential_id:
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
assert credential_record is not None
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=credential_record.credential_name,
encrypted_config=credential_record.encrypted_config,
credential_id=credential_id,
credential_source_type=credential_source,
)
else:
if not credentials:
raise ValueError("Invalid load balancing config credentials")
if not credentials:
raise ValueError("Invalid load balancing config credentials")
if not isinstance(credentials, dict):
raise ValueError("Invalid load balancing config credentials")
if not isinstance(credentials, dict):
raise ValueError("Invalid load balancing config credentials")
# validate custom provider config
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type_enum,
model=model,
credentials=credentials,
validate=False,
)
# validate custom provider config
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type_enum,
model=model,
credentials=credentials,
validate=False,
)
# create load balancing config
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
)
# create load balancing config
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
)
db.session.add(load_balancing_model_config)
db.session.commit()
@ -429,8 +478,8 @@ class ModelLoadBalancingService:
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
) -> None:
config_id: str | None = None,
):
"""
Validate load balancing credentials.
:param tenant_id: workspace id
@ -487,9 +536,9 @@ class ModelLoadBalancingService:
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
load_balancing_model_config: LoadBalancingModelConfig | None = None,
validate: bool = True,
) -> dict:
):
"""
Validate custom credentials.
:param tenant_id: workspace id
@ -557,7 +606,7 @@ class ModelLoadBalancingService:
else:
raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
"""
Clear credentials cache.
:param tenant_id: workspace id

View File

@ -1,7 +1,6 @@
import logging
from typing import Optional
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
@ -16,6 +15,7 @@ from services.entities.model_provider_entities import (
SimpleProviderEntityResponse,
SystemConfigurationResponse,
)
from services.errors.app_model_config import ProviderNotFoundError
logger = logging.getLogger(__name__)
@ -25,10 +25,33 @@ class ModelProviderService:
Model Provider Service
"""
def __init__(self) -> None:
def __init__(self):
    # Single ProviderManager instance reused by every service method to
    # fetch per-tenant provider configurations.
    self.provider_manager = ProviderManager()
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
def _get_provider_configuration(self, tenant_id: str, provider: str):
    """
    Get provider configuration or raise exception if not found.

    Args:
        tenant_id: Workspace identifier
        provider: Provider name

    Returns:
        Provider configuration instance

    Raises:
        ProviderNotFoundError: If provider doesn't exist
    """
    # Resolve the provider among every configuration of this workspace.
    configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
    if not configuration:
        raise ProviderNotFoundError(f"Provider {provider} does not exist.")
    return configuration
def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
"""
get provider list.
@ -46,6 +69,10 @@ class ModelProviderService:
if model_type_entity not in provider_configuration.provider.supported_model_types:
continue
provider_config = provider_configuration.custom_configuration.provider
model_config = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models
provider_response = ProviderResponse(
tenant_id=tenant_id,
provider=provider_configuration.provider.provider,
@ -63,7 +90,12 @@ class ModelProviderService:
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
else CustomConfigurationStatus.NO_CONFIGURE,
current_credential_id=getattr(provider_config, "current_credential_id", None),
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(
enabled=provider_configuration.system_configuration.enabled,
@ -82,8 +114,8 @@ class ModelProviderService:
For the model provider page,
only supports passing in a single provider to query the list of supported models.
:param tenant_id:
:param provider:
:param tenant_id: workspace id
:param provider: provider name
:return:
"""
# Get all provider configurations of the current workspace
@ -95,100 +127,109 @@ class ModelProviderService:
for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
    """
    get provider credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param credential_id: credential id, if not provided, return current used credentials
    :return: credentials dict, or None when the provider has none
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    return provider_configuration.get_provider_credential(credential_id=credential_id)  # type: ignore
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
    """
    validate provider credentials before saving.
    :param tenant_id: workspace id
    :param provider: provider name
    :param credentials: provider credentials dict
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
    self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
) -> None:
    """
    Create and save new provider credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param credentials: provider credentials dict
    :param credential_name: credential name
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.create_provider_credential(credentials, credential_name)
def update_provider_credential(
    self,
    tenant_id: str,
    provider: str,
    credentials: dict,
    credential_id: str,
    credential_name: str | None,
) -> None:
    """
    update a saved provider credential (by credential_id).
    :param tenant_id: workspace id
    :param provider: provider name
    :param credentials: provider credentials dict
    :param credential_id: credential id
    :param credential_name: credential name
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.update_provider_credential(
        credential_id=credential_id,
        credentials=credentials,
        credential_name=credential_name,
    )
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
    """
    remove a saved provider credential (by credential_id).
    :param tenant_id: workspace id
    :param provider: provider name
    :param credential_id: credential id
    :return:
    """
    # Resolve the workspace's provider configuration, then drop the credential.
    self._get_provider_configuration(tenant_id, provider).delete_provider_credential(credential_id=credential_id)
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
    """
    Make the saved credential identified by credential_id the active one.
    :param tenant_id: workspace id
    :param provider: provider name
    :param credential_id: credential id
    :return:
    """
    self._get_provider_configuration(tenant_id, provider).switch_active_provider_credential(
        credential_id=credential_id
    )
def get_model_credential(
    self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> dict | None:
    """
    Retrieve model-specific credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credential_id: Optional credential ID, uses current if not provided
    :return: credentials dict, or None when the model has none
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    return provider_configuration.get_custom_model_credential(  # type: ignore
        model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
    )
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
    """
    validate model credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credentials: model credentials dict
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.validate_custom_model_credentials(
        model_type=ModelType.value_of(model_type), model=model, credentials=credentials
    )
def create_model_credential(
    self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
) -> None:
    """
    create and save model credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credentials: model credentials dict
    :param credential_name: credential name
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.create_custom_model_credential(
        model_type=ModelType.value_of(model_type),
        model=model,
        credentials=credentials,
        credential_name=credential_name,
    )
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
def update_model_credential(
    self,
    tenant_id: str,
    provider: str,
    model_type: str,
    model: str,
    credentials: dict,
    credential_id: str,
    credential_name: str | None,
) -> None:
    """
    update model credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credentials: model credentials dict
    :param credential_id: credential id
    :param credential_name: credential name
    :return:
    """
    configuration = self._get_provider_configuration(tenant_id, provider)
    model_type_enum = ModelType.value_of(model_type)
    configuration.update_custom_model_credential(
        model_type=model_type_enum,
        model=model,
        credentials=credentials,
        credential_id=credential_id,
        credential_name=credential_name,
    )
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
    """
    remove model credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credential_id: credential id
    :return:
    """
    configuration = self._get_provider_configuration(tenant_id, provider)
    configuration.delete_custom_model_credential(
        model=model, model_type=ModelType.value_of(model_type), credential_id=credential_id
    )
def switch_active_custom_model_credential(
    self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
):
    """
    switch model credentials.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credential_id: credential id
    :return:
    """
    configuration = self._get_provider_configuration(tenant_id, provider)
    configuration.switch_custom_model_credential(
        model=model, model_type=ModelType.value_of(model_type), credential_id=credential_id
    )
def add_model_credential_to_model_list(
    self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
):
    """
    add model credentials to model list.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credential_id: credential id
    :return:
    """
    self._get_provider_configuration(tenant_id, provider).add_model_credential_to_model(
        model=model, model_type=ModelType.value_of(model_type), credential_id=credential_id
    )
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
    """
    remove a custom model from the provider.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
"""
@ -271,7 +375,7 @@ class ModelProviderService:
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider available models
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)
# Group models by provider
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
@ -282,9 +386,6 @@ class ModelProviderService:
if model.deprecated:
continue
if model.status != ModelStatus.ACTIVE:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
@ -331,13 +432,7 @@ class ModelProviderService:
:param model: model name
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
provider_configuration = self._get_provider_configuration(tenant_id, provider)
# fetch credentials
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
@ -351,7 +446,7 @@ class ModelProviderService:
return model_schema.parameter_rules if model_schema else []
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
"""
get default model of model type.
@ -383,7 +478,7 @@ class ModelProviderService:
logger.debug("get_default_model_of_model_type error: %s", e)
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
"""
update default model of model type.
@ -400,7 +495,7 @@ class ModelProviderService:
def get_model_provider_icon(
self, tenant_id: str, provider: str, icon_type: str, lang: str
) -> tuple[Optional[bytes], Optional[str]]:
) -> tuple[bytes | None, str | None]:
"""
get model provider icon.
@ -415,7 +510,7 @@ class ModelProviderService:
return byte_data, mime_type
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
    """
    switch preferred provider.
    :param tenant_id: workspace id
    :param provider: provider name
    :param preferred_provider_type: preferred provider type
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    # Convert the raw string to the ProviderType enum before switching.
    preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type)
    provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
    """
    enable model.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :param model_type: model type
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
    """
    disable model.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :param model_type: model type
    :return:
    """
    provider_configuration = self._get_provider_configuration(tenant_id, provider)
    provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))

View File

@ -0,0 +1,94 @@
import enum
import uuid
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.model import OAuthProviderApp
from services.account_service import AccountService
class OAuthGrantType(enum.StrEnum):
AUTHORIZATION_CODE = "authorization_code"
REFRESH_TOKEN = "refresh_token"
# Redis key templates for OAuth state. Keys are scoped per client_id so one
# client's codes/tokens cannot be redeemed by another client.
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12  # 12 hours
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30  # 30 days
class OAuthServerService:
@staticmethod
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
with Session(db.engine) as session:
return session.execute(query).scalar_one_or_none()
@staticmethod
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
code = str(uuid.uuid4())
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
return code
@staticmethod
def sign_oauth_access_token(
grant_type: OAuthGrantType,
code: str = "",
client_id: str = "",
refresh_token: str = "",
) -> tuple[str, str]:
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid code")
# delete code
redis_client.delete(redis_key)
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
return access_token, refresh_token
case OAuthGrantType.REFRESH_TOKEN:
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid refresh token")
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
return access_token, refresh_token
@staticmethod
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
return token
@staticmethod
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
return token
@staticmethod
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
return None
user_id_str = user_account_id.decode("utf-8")
return AccountService.load_user(user_id_str)

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
@ -15,7 +15,7 @@ class OpsService:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@ -134,17 +134,26 @@ class OpsService:
# get project url
if tracing_provider in ("arize", "phoenix"):
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception:
project_url = None
elif tracing_provider == "langfuse":
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
project_url = f"{tracing_config.get('host')}/project/{project_key}"
try:
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
project_url = f"{tracing_config.get('host')}/project/{project_key}"
except Exception:
project_url = None
elif tracing_provider in ("langsmith", "opik"):
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
try:
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
except Exception:
project_url = None
else:
project_url = None
# check if trace config already exists
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()

View File

@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls) -> None:
def migrate(cls):
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
@ -26,7 +26,7 @@ class PluginDataMigration:
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
@classmethod
def migrate_datasets(cls) -> None:
def migrate_datasets(cls):
table_name = "datasets"
provider_column_name = "embedding_model_provider"
@ -126,9 +126,7 @@ limit 1000"""
)
@classmethod
def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
@ -175,7 +173,7 @@ limit 1000"""
# update jina to langgenius/jina_tool/jina etc.
updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id))
except Exception as e:
except Exception:
failed_ids.append(record_id)
click.echo(
click.style(

View File

@ -10,7 +10,7 @@ class PluginAutoUpgradeService:
with Session(db.engine) as session:
return (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
@ -26,7 +26,7 @@ class PluginAutoUpgradeService:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:
@ -54,7 +54,7 @@ class PluginAutoUpgradeService:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional
from typing import Any
from uuid import uuid4
import click
@ -33,7 +33,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int) -> None:
def extract_plugins(cls, filepath: str, workers: int):
"""
Migrate plugin.
"""
@ -55,7 +55,7 @@ class PluginMigration:
thread_pool = ThreadPoolExecutor(max_workers=workers)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
with flask_app.app_context():
nonlocal handled_tenant_count
try:
@ -99,6 +99,7 @@ class PluginMigration:
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
@ -255,7 +256,7 @@ class PluginMigration:
return []
agent_app_model_config_ids = [
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
]
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
@ -280,7 +281,7 @@ class PluginMigration:
return result
@classmethod
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
"""
Fetch plugin unique identifier using plugin id.
"""
@ -291,7 +292,7 @@ class PluginMigration:
return plugin_manifest[0].latest_package_identifier
@classmethod
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
"""
Extract unique plugins.
"""
@ -328,7 +329,7 @@ class PluginMigration:
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
@classmethod
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
"""
Install plugins.
"""
@ -348,7 +349,7 @@ class PluginMigration:
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
def install(tenant_id: str, plugin_ids: list[str]) -> None:
def install(tenant_id: str, plugin_ids: list[str]):
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)

View File

@ -1,7 +1,6 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from typing import Optional
from pydantic import BaseModel
@ -46,11 +45,11 @@ class PluginService:
REDIS_TTL = 60 * 5 # 5 minutes
@staticmethod
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
Fetch the latest plugin version
"""
result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
result: dict[str, PluginService.LatestPluginCache | None] = {}
try:
cache_not_exists = []
@ -109,7 +108,7 @@ class PluginService:
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
@staticmethod
def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
"""
Check the plugin installation scope
"""
@ -144,7 +143,7 @@ class PluginService:
return manager.get_debugging_key(tenant_id)
@staticmethod
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
List the latest versions of the plugins
"""

View File

@ -1,7 +1,6 @@
import json
from os import path
from pathlib import Path
from typing import Optional
from flask import current_app
@ -14,12 +13,12 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from buildin, the location is constants/recommended_apps.json
"""
builtin_data: Optional[dict] = None
builtin_data: dict | None = None
def get_type(self) -> str:
return RecommendAppType.BUILDIN
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_builtin(language)
return result
@ -28,7 +27,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return result
@classmethod
def _get_builtin_data(cls) -> dict:
def _get_builtin_data(cls):
"""
Get builtin data.
:return:
@ -44,7 +43,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return cls.builtin_data or {}
@classmethod
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
def fetch_recommended_apps_from_builtin(cls, language: str):
"""
Fetch recommended apps from builtin.
:param language: language
@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return builtin_data.get("recommended_apps", {}).get(language, {})
@classmethod
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from builtin.
:param app_id: App ID

View File

@ -1,4 +1,4 @@
from typing import Optional
from sqlalchemy import select
from constants.languages import languages
from extensions.ext_database import db
@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from database
"""
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_db(language)
return result
@ -25,24 +25,20 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.DATABASE
@classmethod
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
def fetch_recommended_apps_from_db(cls, language: str):
"""
Fetch recommended apps from db.
:param language: language
:return:
"""
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
).all()
if len(recommended_apps) == 0:
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
).all()
categories = set()
recommended_apps_result = []
@ -74,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
@classmethod
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from db.
:param app_id: App ID

View File

@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC):
"""Interface for recommend app retrieval."""
@abstractmethod
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
raise NotImplementedError
@abstractmethod

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
import requests
@ -24,7 +23,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)
return result
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
try:
result = self.fetch_recommended_apps_from_dify_official(language)
except Exception as e:
@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.REMOTE
@classmethod
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from dify official.
:param app_id: App ID
@ -51,7 +50,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return data
@classmethod
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
def fetch_recommended_apps_from_dify_official(cls, language: str):
"""
Fetch recommended apps from dify official.
:param language: language

View File

@ -1,12 +1,10 @@
from typing import Optional
from configs import dify_config
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
class RecommendedAppService:
@classmethod
def get_recommended_apps_and_categories(cls, language: str) -> dict:
def get_recommended_apps_and_categories(cls, language: str):
"""
Get recommended apps and categories.
:param language: language
@ -15,7 +13,7 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result = retrieval_instance.get_recommended_apps_and_categories(language)
if not result.get("recommended_apps") and language != "en-US":
if not result.get("recommended_apps"):
result = (
RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin(
"en-US"
@ -25,7 +23,7 @@ class RecommendedAppService:
return result
@classmethod
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
def get_recommend_app_detail(cls, app_id: str) -> dict | None:
"""
Get recommend app detail.
:param app_id: app id

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -11,7 +11,7 @@ from services.message_service import MessageService
class SavedMessageService:
@classmethod
def pagination_by_last_id(
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
@ -32,7 +32,7 @@ class SavedMessageService:
)
@classmethod
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (
@ -62,7 +62,7 @@ class SavedMessageService:
db.session.commit()
@classmethod
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (

View File

@ -1,8 +1,7 @@
import uuid
from typing import Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
@ -25,46 +24,41 @@ class TagService:
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tags = (
db.session.query(Tag)
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
tags = db.session.scalars(
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tag_bindings = (
db.session.query(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
tag_bindings = db.session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
)
).all()
return tag_bindings
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
if not tag_type or not tag_name:
return []
tags = (
db.session.query(Tag)
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
tags = list(
db.session.scalars(
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
)
if not tags:
return []
return tags
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
@ -117,7 +111,7 @@ class TagService:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
@ -443,9 +444,7 @@ class ApiToolManageService:
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
result: list[ToolProviderApiEntity] = []

View File

@ -3,8 +3,9 @@ import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
@ -190,11 +191,14 @@ class BuiltinToolManageService:
# update name if provided
if name and name != db_provider.name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")
@ -219,8 +223,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
with Session(db.engine) as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -246,11 +250,14 @@ class BuiltinToolManageService:
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")
@ -278,9 +285,9 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
@ -453,7 +460,7 @@ class BuiltinToolManageService:
check if oauth system client exists
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
@ -467,7 +474,7 @@ class BuiltinToolManageService:
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
@ -492,7 +499,7 @@ class BuiltinToolManageService:
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
@ -546,65 +553,64 @@ class BuiltinToolManageService:
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# get all user added providers
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[ToolProviderApiEntity] = []
result: list[ToolProviderApiEntity] = []
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):
continue
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.entity.identity.name,
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:
raise e
result.append(user_builtin_provider)
except Exception as e:
raise e
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
with Session(db.engine, autoflush=False) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
@ -659,8 +665,8 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
client_params: dict | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""
setup oauth custom client

View File

@ -1,7 +1,7 @@
import hashlib
import json
from datetime import datetime
from typing import Any
from typing import Any, cast
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
@ -27,6 +27,36 @@ class MCPToolManageService:
Service class for managing mcp tools.
"""
@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
Args:
headers: Dictionary of headers to encrypt
tenant_id: Tenant ID for encryption
Returns:
Dictionary with all headers encrypted
"""
if not headers:
return {}
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
return cast(dict[str, str], encrypter_instance.encrypt(headers))
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
@ -61,6 +91,7 @@ class MCPToolManageService:
server_identifier: str,
timeout: float,
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
@ -83,6 +114,12 @@ class MCPToolManageService:
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
@ -95,6 +132,7 @@ class MCPToolManageService:
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
@ -118,9 +156,21 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
@ -172,6 +222,7 @@ class MCPToolManageService:
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
@ -207,6 +258,13 @@ class MCPToolManageService:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
# Encrypt headers
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
mcp_provider.encrypted_headers = None
db.session.commit()
except IntegrityError as e:
db.session.rollback()
@ -226,10 +284,10 @@ class MCPToolManageService:
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
provider_controller = MCPToolProviderController._from_db(mcp_provider)
provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
@ -242,6 +300,12 @@ class MCPToolManageService:
@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(
server_url,
@ -249,6 +313,9 @@ class MCPToolManageService:
tenant_id,
authed=False,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return {

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
"""
list tool providers

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from yarl import URL
@ -94,7 +94,7 @@ class ToolTransformService:
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: Optional[BuiltinToolProvider],
db_provider: BuiltinToolProvider | None,
decrypt_credentials: bool = True,
) -> ToolProviderApiEntity:
"""
@ -128,7 +128,7 @@ class ToolTransformService:
)
}
for name, value in schema.items():
for name in schema:
if result.masked_credentials:
result.masked_credentials[name] = ""
@ -237,6 +237,10 @@ class ToolTransformService:
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
)
@staticmethod

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from sqlalchemy import or_, select
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@ -37,7 +37,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
@ -103,7 +103,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
"""
Update a workflow tool.
:param user_id: the user id
@ -186,7 +186,9 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
@ -217,7 +219,7 @@ class WorkflowToolManageService:
return result
@classmethod
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
@ -233,7 +235,7 @@ class WorkflowToolManageService:
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -249,7 +251,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -265,7 +267,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -13,13 +12,13 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
_logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
@ -27,7 +26,7 @@ class VectorService:
if doc_form == IndexType.PARENT_CHILD_INDEX:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
_logger.warning(
logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
segment.document_id,
segment.id,
@ -79,7 +78,7 @@ class VectorService:
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
@classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
# update segment index task
# format new index

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -19,11 +19,11 @@ class WebConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
pinned: Optional[bool] = None,
pinned: bool | None = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
if not user:
@ -60,7 +60,7 @@ class WebConversationService:
)
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (
@ -92,7 +92,7 @@ class WebConversationService:
db.session.commit()
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (

View File

@ -1,7 +1,7 @@
import enum
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast
from typing import Any
from werkzeug.exceptions import NotFound, Unauthorized
@ -42,7 +42,7 @@ class WebAppAuthService:
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountPasswordError("Invalid email or password.")
return cast(Account, account)
return account
@classmethod
def login(cls, account: Account) -> str:
@ -63,7 +63,7 @@ class WebAppAuthService:
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US"
cls, account: Account | None = None, email: str | None = None, language: str = "en-US"
):
email = account.email if account else email
if email is None:
@ -82,7 +82,7 @@ class WebAppAuthService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@ -113,7 +113,7 @@ class WebAppAuthService:
@classmethod
def _get_account_jwt_token(cls, account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp())
payload = {
@ -130,7 +130,7 @@ class WebAppAuthService:
@classmethod
def is_app_require_permission_check(
cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None
) -> bool:
"""
Check if the app requires permission check based on its access mode.

View File

@ -1,7 +1,7 @@
import datetime
import json
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
import requests
from flask_login import current_user
@ -21,9 +21,9 @@ class CrawlOptions:
limit: int = 1
crawl_sub_pages: bool = False
only_main_content: bool = False
includes: Optional[str] = None
excludes: Optional[str] = None
max_depth: Optional[int] = None
includes: str | None = None
excludes: str | None = None
max_depth: int | None = None
use_sitemap: bool = True
def get_include_paths(self) -> list[str]:
@ -132,7 +132,7 @@ class WebsiteService:
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
@classmethod
def document_create_args_validate(cls, args: dict) -> None:
def document_create_args_validate(cls, args: dict):
"""Validate arguments for document creation."""
try:
WebsiteCrawlApiRequest.from_args(args)

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Optional
from typing import Any
from core.app.app_config.entities import (
DatasetEntity,
@ -18,6 +18,7 @@ from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes import NodeType
from events.app_event import app_was_created
from extensions.ext_database import db
@ -64,7 +65,7 @@ class WorkflowConverter:
new_app = App()
new_app.tenant_id = app_model.tenant_id
new_app.name = name or app_model.name + "(workflow)"
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
new_app.icon_type = icon_type or app_model.icon_type
new_app.icon = icon or app_model.icon
new_app.icon_background = icon_background or app_model.icon_background
@ -202,7 +203,7 @@ class WorkflowConverter:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_model.mode = AppMode.AGENT_CHAT
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
@ -217,7 +218,7 @@ class WorkflowConverter:
return app_config
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
def _convert_to_start_node(self, variables: list[VariableEntity]):
"""
Convert to Start Node
:param variables: list of variables
@ -278,7 +279,7 @@ class WorkflowConverter:
"app_id": app_model.id,
"tool_variable": tool_variable,
"inputs": inputs,
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "",
},
}
@ -326,7 +327,7 @@ class WorkflowConverter:
def _convert_to_knowledge_retrieval_node(
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
) -> Optional[dict]:
) -> dict | None:
"""
Convert datasets to Knowledge Retrieval Node
:param new_app_mode: new app mode
@ -382,9 +383,9 @@ class WorkflowConverter:
graph: dict,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileUploadConfig] = None,
file_upload: FileUploadConfig | None = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
) -> dict:
):
"""
Convert to LLM Node
:param original_app_mode: original app mode
@ -402,7 +403,7 @@ class WorkflowConverter:
)
role_prefix = None
prompts: Optional[Any] = None
prompts: Any | None = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
@ -420,7 +421,11 @@ class WorkflowConverter:
query_in_prompt=False,
)
template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
if not template:
prompts = []
else:
@ -457,7 +462,11 @@ class WorkflowConverter:
query_in_prompt=False,
)
template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
template = self._replace_template_variables(
template=template,
variables=start_node["data"]["variables"],
@ -467,6 +476,9 @@ class WorkflowConverter:
prompts = {"text": template}
prompt_rules = prompt_template_config["prompt_rules"]
if not isinstance(prompt_rules, dict):
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
role_prefix = {
"user": prompt_rules.get("human_prefix", "Human"),
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
@ -550,7 +562,7 @@ class WorkflowConverter:
return template
def _convert_to_end_node(self) -> dict:
def _convert_to_end_node(self):
"""
Convert to End Node
:return:
@ -566,7 +578,7 @@ class WorkflowConverter:
},
}
def _convert_to_answer_node(self) -> dict:
def _convert_to_answer_node(self):
"""
Convert to Answer Node
:return:
@ -578,7 +590,7 @@ class WorkflowConverter:
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str) -> dict:
def _create_edge(self, source: str, target: str):
"""
Create Edge
:param source: source node id
@ -587,7 +599,7 @@ class WorkflowConverter:
"""
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict, node: dict) -> dict:
def _append_node(self, graph: dict, node: dict):
"""
Append Node to Graph
@ -606,7 +618,7 @@ class WorkflowConverter:
:param app_model: App instance
:return: AppMode
"""
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return AppMode.WORKFLOW
else:
return AppMode.ADVANCED_CHAT

View File

@ -23,7 +23,7 @@ class WorkflowAppService:
limit: int = 20,
created_by_end_user_session_id: str | None = None,
created_by_account: str | None = None,
) -> dict:
):
"""
Get paginate workflow app logs using SQLAlchemy 2.0 style
:param session: SQLAlchemy session

View File

@ -28,7 +28,7 @@ from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory
_logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
@ -67,7 +67,7 @@ class DraftVarLoader(VariableLoader):
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
) -> None:
):
self._engine = engine
self._app_id = app_id
self._tenant_id = tenant_id
@ -117,7 +117,7 @@ class DraftVarLoader(VariableLoader):
class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
def __init__(self, session: Session):
"""
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
@ -242,7 +242,7 @@ class WorkflowDraftVariableService:
if conv_var is None:
self._session.delete(instance=variable)
self._session.flush()
_logger.warning(
logger.warning(
"Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name
)
return None
@ -263,12 +263,12 @@ class WorkflowDraftVariableService:
if variable.node_execution_id is None:
self._session.delete(instance=variable)
self._session.flush()
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None
node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
if node_exec is None:
_logger.warning(
logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
variable.id,
variable.name,
@ -351,7 +351,7 @@ class WorkflowDraftVariableService:
return None
segment = draft_var.get_value()
if not isinstance(segment, StringSegment):
_logger.warning(
logger.warning(
"sys.conversation_id variable is not a string: app_id=%s, id=%s",
app_id,
draft_var.id,
@ -438,7 +438,7 @@ def _batch_upsert_draft_variable(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
) -> None:
):
if not draft_vars:
return None
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
@ -681,7 +681,7 @@ class DraftVariableSaver:
draft_vars = []
for name, value in output.items():
if not self._should_variable_be_saved(name):
_logger.debug(
logger.debug(
"Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s",
name,
self._node_type,

View File

@ -1,6 +1,5 @@
import threading
from collections.abc import Sequence
from typing import Optional
from sqlalchemy.orm import sessionmaker
@ -80,7 +79,7 @@ class WorkflowRunService:
last_id=last_id,
)
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
"""
Get workflow run detail

View File

@ -2,10 +2,10 @@ import json
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
@ -37,23 +37,15 @@ from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
Workflow,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
from .chatflow_memory_service import ChatflowMemoryService
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import (
DraftVariableSaver,
DraftVarLoader,
WorkflowDraftVariableService,
)
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
class WorkflowService:
@ -89,20 +81,21 @@ class WorkflowService:
)
def is_workflow_exist(self, app_model: App) -> bool:
return (
db.session.query(Workflow)
.where(
stmt = select(
exists().where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.count()
) > 0
)
return db.session.execute(stmt).scalar_one()
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
"""
Get draft workflow
"""
if workflow_id:
return self.get_published_workflow_by_id(app_model, workflow_id)
# fetch draft workflow by app_model
workflow = (
db.session.query(Workflow)
@ -117,8 +110,10 @@ class WorkflowService:
# return draft workflow
return workflow
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
# fetch published workflow by workflow_id
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
"""
fetch published workflow by workflow_id
"""
workflow = (
db.session.query(Workflow)
.where(
@ -137,7 +132,7 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
def get_published_workflow(self, app_model: App) -> Workflow | None:
"""
Get published workflow
"""
@ -202,7 +197,7 @@ class WorkflowService:
app_model: App,
graph: dict,
features: dict,
unique_hash: Optional[str],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
@ -270,6 +265,12 @@ class WorkflowService:
if not draft_workflow:
raise ValueError("No valid workflow found.")
# Validate credentials before publishing, for credential policy check
from services.feature_service import FeatureService
if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)
# create new workflow
workflow = Workflow.new(
tenant_id=app_model.tenant_id,
@ -294,6 +295,260 @@ class WorkflowService:
# return new workflow
return workflow
def _validate_workflow_credentials(self, workflow: Workflow) -> None:
    """
    Validate all credentials referenced by workflow nodes before publishing.

    Walks every node in the workflow graph and, depending on the node type,
    checks tool credentials, agent model/tool credentials, or LLM-style
    model configurations for credential policy compliance.

    :param workflow: the workflow to validate
    :raises ValueError: if any node is misconfigured or its credentials
        violate policy compliance
    """
    # Hoisted out of the per-node loop; the original re-imported this on every tool node.
    from core.helper.credential_utils import check_credential_policy_compliance

    graph_dict = workflow.graph_dict
    nodes = graph_dict.get("nodes", [])

    for node in nodes:
        node_data = node.get("data", {})
        node_type = node_data.get("type")
        node_id = node.get("id", "unknown")

        try:
            if node_type == "tool":
                credential_id = node_data.get("credential_id")
                provider = node_data.get("provider_id")
                if provider:
                    if credential_id:
                        # An explicit credential is attached to the node — check it directly.
                        check_credential_policy_compliance(
                            credential_id=credential_id,
                            provider=provider,
                            credential_type=PluginCredentialType.TOOL,
                        )
                    else:
                        # No explicit credential — fall back to the workspace default.
                        self._check_default_tool_credential(workflow.tenant_id, provider)

            elif node_type == "agent":
                agent_params = node_data.get("agent_parameters", {})
                model_config = agent_params.get("model", {}).get("value", {})
                if model_config.get("provider") and model_config.get("model"):
                    self._validate_llm_model_config(
                        workflow.tenant_id, model_config["provider"], model_config["model"]
                    )
                    # Also validate load balancing credentials for the agent's model
                    # when load balancing is enabled for it.
                    agent_model_node_data = {"model": model_config}
                    self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)

                # Validate the agent's attached tools.
                tools = agent_params.get("tools", {}).get("value", [])
                for tool in tools:
                    # Agent tools store the provider in the provider_name field.
                    provider = tool.get("provider_name")
                    credential_id = tool.get("credential_id")
                    if provider:
                        if credential_id:
                            check_credential_policy_compliance(
                                credential_id, provider, PluginCredentialType.TOOL
                            )
                        else:
                            self._check_default_tool_credential(workflow.tenant_id, provider)

            elif node_type in ("llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"):
                model_config = node_data.get("model", {})
                provider = model_config.get("provider")
                model_name = model_config.get("name")
                if provider and model_name:
                    # Validate that the provider+model combination can fetch valid credentials.
                    self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
                    # Validate load balancing credentials if load balancing is enabled.
                    self._validate_load_balancing_credentials(workflow, node_data, node_id)
                else:
                    raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")
        except ValueError:
            # Already a user-facing validation error; propagate unchanged
            # (bare raise preserves the original traceback).
            raise
        except Exception as e:
            # Wrap unexpected failures with node context, keeping the cause chain.
            raise ValueError(f"Node {node_id} ({node_type}): {e}") from e
def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
    """
    Validate that an LLM model configuration can fetch valid credentials.

    This attempts to get a model instance and thereby validates that:
    1. the provider exists and is configured,
    2. the model exists in the provider,
    3. credentials can be fetched for the model, and
    4. the credentials pass policy compliance checks.

    :param tenant_id: the tenant ID
    :param provider: the provider name
    :param model_name: the model name
    :raises ValueError: if the model configuration is invalid or its
        credentials fail policy checks
    """
    try:
        from core.model_manager import ModelManager
        from core.model_runtime.entities.model_entities import ModelType

        # Instantiating the model validates the provider+model combination.
        # The ModelInstance constructor also checks credential policy compliance
        # via ProviderConfiguration.get_current_credentials(); failures raise here.
        ModelManager().get_model_instance(
            tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
        )
    except Exception as e:
        # Chain the original exception so the root cause stays visible in logs.
        raise ValueError(
            f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {e}"
        ) from e
def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None:
    """
    Check credential policy compliance for the default workspace credential
    of a tool provider.

    Finds the default credential for the given provider and validates it,
    using the same fallback logic as runtime so deauthorized credentials are
    handled consistently.

    :param tenant_id: the tenant ID
    :param provider: the tool provider name
    :raises ValueError: if no default credential exists or it fails policy
        compliance
    """
    from core.helper.credential_utils import check_credential_policy_compliance
    from models.tools import BuiltinToolProvider

    try:
        # Same fallback logic as runtime: take the first available credential,
        # ordered by is_default DESC, created_at ASC (mirrors tool_manager.py).
        default_provider = (
            db.session.query(BuiltinToolProvider)
            .where(
                BuiltinToolProvider.tenant_id == tenant_id,
                BuiltinToolProvider.provider == provider,
            )
            .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
            .first()
        )
        if not default_provider:
            raise ValueError("No default credential found")

        # Check credential policy compliance using the default credential ID.
        check_credential_policy_compliance(
            credential_id=default_provider.id,
            provider=provider,
            credential_type=PluginCredentialType.TOOL,
            check_existence=False,
        )
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Failed to validate default credential for tool provider {provider}: {e}") from e
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
    """
    Validate load balancing credentials for a workflow node.

    No-op when the node has no model configuration or when load balancing
    is not enabled for the configured model.

    :param workflow: the workflow being validated
    :param node_data: the node data containing the model configuration
    :param node_id: the node ID (kept for interface compatibility / error context)
    :raises ValueError: if any load balancing credential violates policy compliance
    """
    # Extract model configuration; bail out early when there is nothing to validate.
    model_config = node_data.get("model", {})
    provider = model_config.get("provider")
    model_name = model_config.get("name")
    if not provider or not model_name:
        return

    if not self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
        return

    # Hoisted out of the loop; the original re-imported this per configuration.
    from core.helper.credential_utils import check_credential_policy_compliance

    load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
    try:
        for config in load_balancing_configs:
            if config.get("credential_id"):
                check_credential_policy_compliance(
                    config["credential_id"], provider, PluginCredentialType.MODEL
                )
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {e}") from e
def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
    """
    Report whether load balancing is enabled for a specific model.

    :param tenant_id: the tenant ID
    :param provider: the provider name
    :param model_name: the model name
    :return: True when a provider model setting exists with load balancing
        turned on; False otherwise (including when the status cannot be
        determined)
    """
    try:
        from core.model_runtime.entities.model_entities import ModelType
        from core.provider_manager import ProviderManager

        # Look up the provider's configuration for this tenant.
        configurations = ProviderManager().get_configurations(tenant_id)
        configuration = configurations.get(provider)
        if not configuration:
            return False

        # A model setting must exist and explicitly enable load balancing.
        setting = configuration.get_provider_model_setting(
            model_type=ModelType.LLM,
            model=model_name,
        )
        return setting is not None and setting.load_balancing_enabled
    except Exception:
        # Best-effort check: when the status cannot be determined,
        # assume load balancing is not enabled.
        return False
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
    """
    Collect the load balancing configurations for a model that reference a credential.

    Queries both the predefined-model and custom-model configuration sources
    and keeps only entries carrying a credential_id.

    :param tenant_id: the tenant ID
    :param provider: the provider name
    :param model_name: the model name
    :return: load balancing configuration dicts with a credential_id;
        empty list when the lookup fails
    """
    try:
        from services.model_load_balancing_service import ModelLoadBalancingService

        service = ModelLoadBalancingService()
        # Load balancing is primarily used for LLM models, so query the "llm"
        # model type from the predefined-model source...
        _, predefined_configs = service.get_load_balancing_configs(
            tenant_id=tenant_id,
            provider=provider,
            model=model_name,
            model_type="llm",
            config_from="predefined-model",
        )
        # ...and from the custom-model source.
        _, custom_configs = service.get_load_balancing_configs(
            tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
        )
        return [cfg for cfg in predefined_configs + custom_configs if cfg.get("credential_id")]
    except Exception:
        # Swallow lookup failures and return nothing to validate, so this
        # helper never breaks workflow publishing.
        return []
def get_default_block_configs(self) -> list[dict]:
"""
Get default block configs
@ -308,7 +563,7 @@ class WorkflowService:
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None:
"""
Get default config of node.
:param node_type: node type
@ -509,10 +764,10 @@ class WorkflowService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node = e._node
node = e.node
run_succeeded = False
node_run_result = None
error = e._error
error = e.error
# Create a NodeExecution domain model
node_execution = WorkflowNodeExecution(
@ -568,7 +823,7 @@ class WorkflowService:
# chatbot convert to workflow mode
workflow_converter = WorkflowConverter()
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
@ -583,12 +838,12 @@ class WorkflowService:
return new_app
def validate_features_structure(self, app_model: App, features: dict) -> dict:
if app_model.mode == AppMode.ADVANCED_CHAT.value:
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
@ -597,7 +852,7 @@ class WorkflowService:
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
) -> Workflow | None:
"""
Update workflow attributes

View File

@ -12,7 +12,7 @@ class WorkspaceService:
def get_tenant_info(cls, tenant: Tenant):
if not tenant:
return None
tenant_info = {
tenant_info: dict[str, object] = {
"id": tenant.id,
"name": tenant.name,
"plan": tenant.plan,