mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge branch 'main' into feat/memory-orchestration-be
# Conflicts: # api/core/app/apps/advanced_chat/app_runner.py # api/core/prompt/entities/advanced_prompt_entities.py # api/core/variables/segments.py
This commit is contained in:
@ -5,7 +5,7 @@ import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
@ -37,7 +37,6 @@ from services.billing_service import BillingService
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
AccountLoginError,
|
||||
AccountNotFoundError,
|
||||
AccountNotLinkTenantError,
|
||||
AccountPasswordError,
|
||||
AccountRegisterError,
|
||||
@ -65,7 +64,13 @@ from tasks.mail_owner_transfer_task import (
|
||||
send_old_owner_transfer_notify_email_task,
|
||||
send_owner_transfer_confirm_task,
|
||||
)
|
||||
from tasks.mail_reset_password_task import send_reset_password_mail_task
|
||||
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
|
||||
from tasks.mail_reset_password_task import (
|
||||
send_reset_password_mail_task,
|
||||
send_reset_password_mail_task_when_account_not_exist,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenPair(BaseModel):
|
||||
@ -80,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
class AccountService:
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_code_login_rate_limiter = RateLimiter(
|
||||
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
|
||||
prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1
|
||||
)
|
||||
email_code_account_deletion_rate_limiter = RateLimiter(
|
||||
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
|
||||
@ -93,6 +99,7 @@ class AccountService:
|
||||
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
|
||||
CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
|
||||
OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
|
||||
EMAIL_REGISTER_MAX_ERROR_LIMITS = 5
|
||||
|
||||
@staticmethod
|
||||
def _get_refresh_token_key(refresh_token: str) -> str:
|
||||
@ -103,14 +110,14 @@ class AccountService:
|
||||
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
|
||||
|
||||
@staticmethod
|
||||
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||
def _store_refresh_token(refresh_token: str, account_id: str):
|
||||
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
|
||||
redis_client.setex(
|
||||
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||
def _delete_refresh_token(refresh_token: str, account_id: str):
|
||||
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
|
||||
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
|
||||
|
||||
@ -143,8 +150,11 @@ class AccountService:
|
||||
if naive_utc_now() - account.last_active_at > timedelta(minutes=10):
|
||||
account.last_active_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
return cast(Account, account)
|
||||
# NOTE: make sure account is accessible outside of a db session
|
||||
# This ensures that it will work correctly after upgrading to Flask version 3.1.2
|
||||
db.session.refresh(account)
|
||||
db.session.close()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def get_account_jwt_token(account: Account) -> str:
|
||||
@ -161,12 +171,12 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
|
||||
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
raise AccountLoginError("Account is banned.")
|
||||
@ -189,7 +199,7 @@ class AccountService:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return cast(Account, account)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_account_password(account, password, new_password):
|
||||
@ -209,6 +219,7 @@ class AccountService:
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
return account
|
||||
|
||||
@ -217,9 +228,9 @@ class AccountService:
|
||||
email: str,
|
||||
name: str,
|
||||
interface_language: str,
|
||||
password: Optional[str] = None,
|
||||
password: str | None = None,
|
||||
interface_theme: str = "light",
|
||||
is_setup: Optional[bool] = False,
|
||||
is_setup: bool | None = False,
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
if not FeatureService.get_system_features().is_allow_register and not is_setup:
|
||||
@ -240,6 +251,8 @@ class AccountService:
|
||||
account.name = name
|
||||
|
||||
if password:
|
||||
valid_password(password)
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
@ -263,7 +276,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def create_account_and_tenant(
|
||||
email: str, name: str, interface_language: str, password: Optional[str] = None
|
||||
email: str, name: str, interface_language: str, password: str | None = None
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
account = AccountService.create_account(
|
||||
@ -288,7 +301,9 @@ class AccountService:
|
||||
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
|
||||
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
|
||||
|
||||
raise EmailCodeAccountDeletionRateLimitExceededError()
|
||||
raise EmailCodeAccountDeletionRateLimitExceededError(
|
||||
int(cls.email_code_account_deletion_rate_limiter.time_window / 60)
|
||||
)
|
||||
|
||||
send_account_deletion_verification_code.delay(to=email, code=code)
|
||||
|
||||
@ -306,16 +321,16 @@ class AccountService:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete_account(account: Account) -> None:
|
||||
def delete_account(account: Account):
|
||||
"""Delete account. This method only adds a task to the queue for deletion."""
|
||||
delete_account_task.delay(account.id)
|
||||
|
||||
@staticmethod
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account):
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: Optional[AccountIntegrate] = (
|
||||
account_integrate: AccountIntegrate | None = (
|
||||
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
|
||||
)
|
||||
|
||||
@ -332,13 +347,13 @@ class AccountService:
|
||||
db.session.add(account_integrate)
|
||||
|
||||
db.session.commit()
|
||||
logging.info("Account %s linked %s account %s.", account.id, provider, open_id)
|
||||
logger.info("Account %s linked %s account %s.", account.id, provider, open_id)
|
||||
except Exception as e:
|
||||
logging.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
|
||||
logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
|
||||
raise LinkAccountIntegrateError("Failed to link account.") from e
|
||||
|
||||
@staticmethod
|
||||
def close_account(account: Account) -> None:
|
||||
def close_account(account: Account):
|
||||
"""Close account"""
|
||||
account.status = AccountStatus.CLOSED.value
|
||||
db.session.commit()
|
||||
@ -346,6 +361,7 @@ class AccountService:
|
||||
@staticmethod
|
||||
def update_account(account, **kwargs):
|
||||
"""Update account fields"""
|
||||
account = db.session.merge(account)
|
||||
for field, value in kwargs.items():
|
||||
if hasattr(account, field):
|
||||
setattr(account, field, value)
|
||||
@ -367,7 +383,7 @@ class AccountService:
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_login_info(account: Account, *, ip_address: str) -> None:
|
||||
def update_login_info(account: Account, *, ip_address: str):
|
||||
"""Update last login time and ip"""
|
||||
account.last_login_at = naive_utc_now()
|
||||
account.last_login_ip = ip_address
|
||||
@ -375,7 +391,7 @@ class AccountService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
|
||||
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
|
||||
if ip_address:
|
||||
AccountService.update_login_info(account=account, ip_address=ip_address)
|
||||
|
||||
@ -391,7 +407,7 @@ class AccountService:
|
||||
return TokenPair(access_token=access_token, refresh_token=refresh_token)
|
||||
|
||||
@staticmethod
|
||||
def logout(*, account: Account) -> None:
|
||||
def logout(*, account: Account):
|
||||
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
|
||||
if refresh_token:
|
||||
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
||||
@ -423,9 +439,10 @@ class AccountService:
|
||||
@classmethod
|
||||
def send_reset_password_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
language: str = "en-US",
|
||||
is_allow_register: bool = False,
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
if account_email is None:
|
||||
@ -434,26 +451,67 @@ class AccountService:
|
||||
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import PasswordResetRateLimitExceededError
|
||||
|
||||
raise PasswordResetRateLimitExceededError()
|
||||
raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_reset_password_token(account_email, account)
|
||||
|
||||
send_reset_password_mail_task.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
code=code,
|
||||
)
|
||||
if account:
|
||||
send_reset_password_mail_task.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
code=code,
|
||||
)
|
||||
else:
|
||||
send_reset_password_mail_task_when_account_not_exist.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
is_allow_register=is_allow_register,
|
||||
)
|
||||
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def send_email_register_email(
|
||||
cls,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
language: str = "en-US",
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
if account_email is None:
|
||||
raise ValueError("Email must be provided.")
|
||||
|
||||
if cls.email_register_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailRegisterRateLimitExceededError
|
||||
|
||||
raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_email_register_token(account_email)
|
||||
|
||||
if account:
|
||||
send_email_register_mail_task_when_account_exist.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
account_name=account.name,
|
||||
)
|
||||
|
||||
else:
|
||||
send_email_register_mail_task.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
code=code,
|
||||
)
|
||||
cls.email_register_rate_limiter.increment_rate_limit(account_email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def send_change_email_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
old_email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
old_email: str | None = None,
|
||||
language: str = "en-US",
|
||||
phase: Optional[str] = None,
|
||||
phase: str | None = None,
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
if account_email is None:
|
||||
@ -464,7 +522,7 @@ class AccountService:
|
||||
if cls.change_email_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailChangeRateLimitExceededError
|
||||
|
||||
raise EmailChangeRateLimitExceededError()
|
||||
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
|
||||
|
||||
@ -480,8 +538,8 @@ class AccountService:
|
||||
@classmethod
|
||||
def send_change_email_completed_notify_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
language: str = "en-US",
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
@ -496,10 +554,10 @@ class AccountService:
|
||||
@classmethod
|
||||
def send_owner_transfer_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
language: str = "en-US",
|
||||
workspace_name: Optional[str] = "",
|
||||
workspace_name: str | None = "",
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
if account_email is None:
|
||||
@ -508,7 +566,7 @@ class AccountService:
|
||||
if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import OwnerTransferRateLimitExceededError
|
||||
|
||||
raise OwnerTransferRateLimitExceededError()
|
||||
raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_owner_transfer_token(account_email, account)
|
||||
workspace_name = workspace_name or ""
|
||||
@ -525,10 +583,10 @@ class AccountService:
|
||||
@classmethod
|
||||
def send_old_owner_transfer_notify_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
language: str = "en-US",
|
||||
workspace_name: Optional[str] = "",
|
||||
workspace_name: str | None = "",
|
||||
new_owner_email: str = "",
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
@ -546,10 +604,10 @@ class AccountService:
|
||||
@classmethod
|
||||
def send_new_owner_transfer_notify_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
language: str = "en-US",
|
||||
workspace_name: Optional[str] = "",
|
||||
workspace_name: str | None = "",
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
if account_email is None:
|
||||
@ -566,8 +624,8 @@ class AccountService:
|
||||
def generate_reset_password_token(
|
||||
cls,
|
||||
email: str,
|
||||
account: Optional[Account] = None,
|
||||
code: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
code: str | None = None,
|
||||
additional_data: dict[str, Any] = {},
|
||||
):
|
||||
if not code:
|
||||
@ -578,13 +636,26 @@ class AccountService:
|
||||
)
|
||||
return code, token
|
||||
|
||||
@classmethod
|
||||
def generate_email_register_token(
|
||||
cls,
|
||||
email: str,
|
||||
code: str | None = None,
|
||||
additional_data: dict[str, Any] = {},
|
||||
):
|
||||
if not code:
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
additional_data["code"] = code
|
||||
token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
|
||||
return code, token
|
||||
|
||||
@classmethod
|
||||
def generate_change_email_token(
|
||||
cls,
|
||||
email: str,
|
||||
account: Optional[Account] = None,
|
||||
code: Optional[str] = None,
|
||||
old_email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
code: str | None = None,
|
||||
old_email: str | None = None,
|
||||
additional_data: dict[str, Any] = {},
|
||||
):
|
||||
if not code:
|
||||
@ -600,8 +671,8 @@ class AccountService:
|
||||
def generate_owner_transfer_token(
|
||||
cls,
|
||||
email: str,
|
||||
account: Optional[Account] = None,
|
||||
code: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
code: str | None = None,
|
||||
additional_data: dict[str, Any] = {},
|
||||
):
|
||||
if not code:
|
||||
@ -616,6 +687,10 @@ class AccountService:
|
||||
def revoke_reset_password_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, "reset_password")
|
||||
|
||||
@classmethod
|
||||
def revoke_email_register_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, "email_register")
|
||||
|
||||
@classmethod
|
||||
def revoke_change_email_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, "change_email")
|
||||
@ -625,22 +700,26 @@ class AccountService:
|
||||
TokenManager.revoke_token(token, "owner_transfer")
|
||||
|
||||
@classmethod
|
||||
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
def get_reset_password_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "reset_password")
|
||||
|
||||
@classmethod
|
||||
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
def get_email_register_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "email_register")
|
||||
|
||||
@classmethod
|
||||
def get_change_email_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "change_email")
|
||||
|
||||
@classmethod
|
||||
def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "owner_transfer")
|
||||
|
||||
@classmethod
|
||||
def send_email_code_login_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
account: Account | None = None,
|
||||
email: str | None = None,
|
||||
language: str = "en-US",
|
||||
):
|
||||
email = account.email if account else email
|
||||
@ -649,7 +728,7 @@ class AccountService:
|
||||
if cls.email_code_login_rate_limiter.is_rate_limited(email):
|
||||
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
|
||||
|
||||
raise EmailCodeLoginRateLimitExceededError()
|
||||
raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))
|
||||
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
token = TokenManager.generate_token(
|
||||
@ -664,7 +743,7 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "email_code_login")
|
||||
|
||||
@classmethod
|
||||
@ -698,7 +777,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_login_error_rate_limit(email: str) -> None:
|
||||
def add_login_error_rate_limit(email: str):
|
||||
key = f"login_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -727,7 +806,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_forgot_password_error_rate_limit(email: str) -> None:
|
||||
def add_forgot_password_error_rate_limit(email: str):
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -735,6 +814,16 @@ class AccountService:
|
||||
count = int(count) + 1
|
||||
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_email_register_error_rate_limit(email: str) -> None:
|
||||
key = f"email_register_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
count = 0
|
||||
count = int(count) + 1
|
||||
redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def is_forgot_password_error_rate_limit(email: str) -> bool:
|
||||
@ -754,9 +843,27 @@ class AccountService:
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def is_email_register_error_rate_limit(email: str) -> bool:
|
||||
key = f"email_register_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
return False
|
||||
count = int(count)
|
||||
if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_change_email_error_rate_limit(email: str) -> None:
|
||||
def reset_email_register_error_rate_limit(email: str):
|
||||
key = f"email_register_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_change_email_error_rate_limit(email: str):
|
||||
key = f"change_email_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -784,7 +891,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_owner_transfer_error_rate_limit(email: str) -> None:
|
||||
def add_owner_transfer_error_rate_limit(email: str):
|
||||
key = f"owner_transfer_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -858,7 +965,7 @@ class AccountService:
|
||||
|
||||
class TenantService:
|
||||
@staticmethod
|
||||
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
|
||||
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
|
||||
"""Create tenant"""
|
||||
if (
|
||||
not FeatureService.get_system_features().is_allow_create_workspace
|
||||
@ -889,9 +996,7 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def create_owner_tenant_if_not_exist(
|
||||
account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
|
||||
):
|
||||
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
|
||||
"""Check if user have a workspace or not"""
|
||||
available_ta = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
@ -925,7 +1030,7 @@ class TenantService:
|
||||
"""Create tenant member"""
|
||||
if role == TenantAccountRole.OWNER.value:
|
||||
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
|
||||
logging.error("Tenant %s has already an owner.", tenant.id)
|
||||
logger.error("Tenant %s has already an owner.", tenant.id)
|
||||
raise Exception("Tenant already has an owner.")
|
||||
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
@ -963,7 +1068,7 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
|
||||
def switch_tenant(account: Account, tenant_id: str | None = None):
|
||||
"""Switch the current workspace for the account"""
|
||||
|
||||
# Ensure tenant_id is provided
|
||||
@ -1045,7 +1150,7 @@ class TenantService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]:
|
||||
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
|
||||
"""Get the role of the current account for a given tenant"""
|
||||
join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
@ -1060,7 +1165,7 @@ class TenantService:
|
||||
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
||||
|
||||
@staticmethod
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
|
||||
"""Check member permission"""
|
||||
perms = {
|
||||
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
@ -1080,7 +1185,7 @@ class TenantService:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
|
||||
"""Remove member from tenant"""
|
||||
if operator.id == account.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
@ -1095,7 +1200,7 @@ class TenantService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
|
||||
@ -1122,10 +1227,10 @@ class TenantService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_custom_config(tenant_id: str) -> dict:
|
||||
def get_custom_config(tenant_id: str):
|
||||
tenant = db.get_or_404(Tenant, tenant_id)
|
||||
|
||||
return cast(dict, tenant.custom_config_dict)
|
||||
return tenant.custom_config_dict
|
||||
|
||||
@staticmethod
|
||||
def is_owner(account: Account, tenant: Tenant) -> bool:
|
||||
@ -1143,7 +1248,7 @@ class RegisterService:
|
||||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str):
|
||||
"""
|
||||
Setup dify
|
||||
|
||||
@ -1177,7 +1282,7 @@ class RegisterService:
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.commit()
|
||||
|
||||
logging.exception("Setup account failed, email: %s, name: %s", email, name)
|
||||
logger.exception("Setup account failed, email: %s, name: %s", email, name)
|
||||
raise ValueError(f"Setup failed: {e}")
|
||||
|
||||
@classmethod
|
||||
@ -1185,13 +1290,13 @@ class RegisterService:
|
||||
cls,
|
||||
email,
|
||||
name,
|
||||
password: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
status: Optional[AccountStatus] = None,
|
||||
is_setup: Optional[bool] = False,
|
||||
create_workspace_required: Optional[bool] = True,
|
||||
password: str | None = None,
|
||||
open_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
language: str | None = None,
|
||||
status: AccountStatus | None = None,
|
||||
is_setup: bool | None = False,
|
||||
create_workspace_required: bool | None = True,
|
||||
) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
@ -1222,15 +1327,15 @@ class RegisterService:
|
||||
db.session.commit()
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
db.session.rollback()
|
||||
logging.exception("Register failed")
|
||||
logger.exception("Register failed")
|
||||
raise AccountRegisterError("Workspace is not allowed to create.")
|
||||
except AccountRegisterError as are:
|
||||
db.session.rollback()
|
||||
logging.exception("Register failed")
|
||||
logger.exception("Register failed")
|
||||
raise are
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logging.exception("Register failed")
|
||||
logger.exception("Register failed")
|
||||
raise AccountRegisterError(f"Registration failed: {e}") from e
|
||||
|
||||
return account
|
||||
@ -1308,10 +1413,8 @@ class RegisterService:
|
||||
redis_client.delete(cls._get_invitation_token_key(token))
|
||||
|
||||
@classmethod
|
||||
def get_invitation_if_token_valid(
|
||||
cls, workspace_id: Optional[str], email: str, token: str
|
||||
) -> Optional[dict[str, Any]]:
|
||||
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
|
||||
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
|
||||
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
@ -1348,9 +1451,9 @@ class RegisterService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_invitation_by_token(
|
||||
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
|
||||
) -> Optional[dict[str, str]]:
|
||||
def get_invitation_by_token(
|
||||
cls, token: str, workspace_id: str | None = None, email: str | None = None
|
||||
) -> dict[str, str] | None:
|
||||
if workspace_id is not None and email is not None:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||
|
||||
@ -17,7 +17,7 @@ from models.model import AppMode
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
@classmethod
|
||||
def get_prompt(cls, args: dict) -> dict:
|
||||
def get_prompt(cls, args: dict):
|
||||
app_mode = args["app_mode"]
|
||||
model_mode = args["model_mode"]
|
||||
model_name = args["model_name"]
|
||||
@ -29,17 +29,17 @@ class AdvancedPromptTemplateService:
|
||||
return cls.get_common_prompt(app_mode, model_mode, has_context)
|
||||
|
||||
@classmethod
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if app_mode == AppMode.CHAT:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
@ -52,7 +52,7 @@ class AdvancedPromptTemplateService:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
|
||||
if has_context == "true":
|
||||
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
|
||||
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
@ -61,7 +61,7 @@ class AdvancedPromptTemplateService:
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
|
||||
if has_context == "true":
|
||||
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
|
||||
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
@ -70,10 +70,10 @@ class AdvancedPromptTemplateService:
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if app_mode == AppMode.CHAT:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
from flask_login import current_user
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
|
||||
@ -10,13 +9,14 @@ from core.plugin.impl.agent import PluginAgentClient
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
@ -35,7 +35,7 @@ class AgentService:
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation not found: {conversation_id}")
|
||||
|
||||
message: Optional[Message] = (
|
||||
message: Message | None = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
Message.id == message_id,
|
||||
@ -61,14 +61,15 @@ class AgentService:
|
||||
executor = executor.name
|
||||
else:
|
||||
executor = "Unknown"
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.timezone is not None
|
||||
timezone = pytz.timezone(current_user.timezone)
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("App model config not found")
|
||||
|
||||
result = {
|
||||
result: dict[str, Any] = {
|
||||
"meta": {
|
||||
"status": "success",
|
||||
"executor": executor,
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import or_, select
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -10,6 +8,8 @@ from werkzeug.exceptions import NotFound
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
||||
@ -24,6 +24,7 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -40,7 +41,7 @@ class AppAnnotationService:
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
annotation = message.annotation
|
||||
annotation: MessageAnnotation | None = message.annotation
|
||||
# save the message annotation
|
||||
if annotation:
|
||||
annotation.content = args["answer"]
|
||||
@ -62,6 +63,7 @@ class AppAnnotationService:
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
assert current_user.current_tenant_id is not None
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
@ -70,10 +72,10 @@ class AppAnnotationService:
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
return cast(MessageAnnotation, annotation)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
def enable_app_annotation(cls, args: dict, app_id: str):
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@ -84,6 +86,8 @@ class AppAnnotationService:
|
||||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
enable_annotation_reply_task.delay(
|
||||
str(job_id),
|
||||
app_id,
|
||||
@ -96,7 +100,9 @@ class AppAnnotationService:
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str) -> dict:
|
||||
def disable_app_annotation(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@ -113,6 +119,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -145,6 +153,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -164,6 +174,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -193,6 +205,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -230,6 +244,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -246,11 +262,9 @@ class AppAnnotationService:
|
||||
|
||||
db.session.delete(annotation)
|
||||
|
||||
annotation_hit_histories = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
.all()
|
||||
)
|
||||
annotation_hit_histories = db.session.scalars(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
).all()
|
||||
if annotation_hit_histories:
|
||||
for annotation_hit_history in annotation_hit_histories:
|
||||
db.session.delete(annotation_hit_history)
|
||||
@ -269,6 +283,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -282,7 +298,7 @@ class AppAnnotationService:
|
||||
annotations_to_delete = (
|
||||
db.session.query(MessageAnnotation, AppAnnotationSetting)
|
||||
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
|
||||
.filter(MessageAnnotation.id.in_(annotation_ids))
|
||||
.where(MessageAnnotation.id.in_(annotation_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
@ -315,8 +331,10 @@ class AppAnnotationService:
|
||||
return {"deleted_count": deleted_count}
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -328,9 +346,9 @@ class AppAnnotationService:
|
||||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file, dtype=str)
|
||||
df = pd.read_csv(file.stream, dtype=str)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
for _, row in df.iterrows():
|
||||
content = {"question": row.iloc[0], "answer": row.iloc[1]}
|
||||
result.append(content)
|
||||
if len(result) == 0:
|
||||
@ -355,6 +373,8 @@ class AppAnnotationService:
|
||||
|
||||
@classmethod
|
||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
@ -425,6 +445,8 @@ class AppAnnotationService:
|
||||
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
@ -438,19 +460,29 @@ class AppAnnotationService:
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
if collection_binding_detail:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {},
|
||||
}
|
||||
return {"enabled": False}
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
@ -479,21 +511,31 @@ class AppAnnotationService:
|
||||
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
if collection_binding_detail:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def clear_all_annotations(cls, app_id: str) -> dict:
|
||||
def clear_all_annotations(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ class APIBasedExtensionService:
|
||||
return extension_data
|
||||
|
||||
@staticmethod
|
||||
def delete(extension_data: APIBasedExtension) -> None:
|
||||
def delete(extension_data: APIBasedExtension):
|
||||
db.session.delete(extension_data)
|
||||
db.session.commit()
|
||||
|
||||
@ -51,7 +51,7 @@ class APIBasedExtensionService:
|
||||
return extension
|
||||
|
||||
@classmethod
|
||||
def _validation(cls, extension_data: APIBasedExtension) -> None:
|
||||
def _validation(cls, extension_data: APIBasedExtension):
|
||||
# name
|
||||
if not extension_data.name:
|
||||
raise ValueError("name must not be empty")
|
||||
@ -95,7 +95,7 @@ class APIBasedExtensionService:
|
||||
cls._ping_connection(extension_data)
|
||||
|
||||
@staticmethod
|
||||
def _ping_connection(extension_data: APIBasedExtension) -> None:
|
||||
def _ping_connection(extension_data: APIBasedExtension):
|
||||
try:
|
||||
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
|
||||
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
|
||||
|
||||
@ -4,7 +4,6 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
@ -17,6 +16,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
@ -42,7 +42,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
|
||||
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
|
||||
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.3.1"
|
||||
CURRENT_DSL_VERSION = "0.4.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
@ -60,8 +60,8 @@ class ImportStatus(StrEnum):
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
app_id: Optional[str] = None
|
||||
app_mode: Optional[str] = None
|
||||
app_id: str | None = None
|
||||
app_mode: str | None = None
|
||||
current_dsl_version: str = CURRENT_DSL_VERSION
|
||||
imported_dsl_version: str = ""
|
||||
error: str = ""
|
||||
@ -98,17 +98,17 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
class PendingData(BaseModel):
|
||||
import_mode: str
|
||||
yaml_content: str
|
||||
name: str | None
|
||||
description: str | None
|
||||
icon_type: str | None
|
||||
icon: str | None
|
||||
icon_background: str | None
|
||||
app_id: str | None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
class CheckDependenciesPendingData(BaseModel):
|
||||
dependencies: list[PluginDependency]
|
||||
app_id: str | None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
class AppDslService:
|
||||
@ -120,14 +120,14 @@ class AppDslService:
|
||||
*,
|
||||
account: Account,
|
||||
import_mode: str,
|
||||
yaml_content: Optional[str] = None,
|
||||
yaml_url: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
icon_type: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_background: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
yaml_content: str | None = None,
|
||||
yaml_url: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
icon_type: str | None = None,
|
||||
icon: str | None = None,
|
||||
icon_background: str | None = None,
|
||||
app_id: str | None = None,
|
||||
) -> Import:
|
||||
"""Import an app from YAML content or URL."""
|
||||
import_id = str(uuid.uuid4())
|
||||
@ -406,15 +406,15 @@ class AppDslService:
|
||||
def _create_or_update_app(
|
||||
self,
|
||||
*,
|
||||
app: Optional[App],
|
||||
app: App | None,
|
||||
data: dict,
|
||||
account: Account,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
icon_type: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_background: Optional[str] = None,
|
||||
dependencies: Optional[list[PluginDependency]] = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
icon_type: str | None = None,
|
||||
icon: str | None = None,
|
||||
icon_background: str | None = None,
|
||||
dependencies: list[PluginDependency] | None = None,
|
||||
) -> App:
|
||||
"""Create a new app or update an existing one."""
|
||||
app_data = data.get("app", {})
|
||||
@ -532,7 +532,7 @@ class AppDslService:
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
|
||||
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str:
|
||||
"""
|
||||
Export app
|
||||
:param app_model: App instance
|
||||
@ -556,7 +556,7 @@ class AppDslService:
|
||||
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
cls._append_workflow_export_data(
|
||||
export_data=export_data, app_model=app_model, include_secret=include_secret
|
||||
export_data=export_data, app_model=app_model, include_secret=include_secret, workflow_id=workflow_id
|
||||
)
|
||||
else:
|
||||
cls._append_model_config_export_data(export_data, app_model)
|
||||
@ -564,14 +564,16 @@ class AppDslService:
|
||||
return yaml.dump(export_data, allow_unicode=True) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
|
||||
def _append_workflow_export_data(
|
||||
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
|
||||
):
|
||||
"""
|
||||
Append workflow export data
|
||||
:param export_data: export data
|
||||
:param app_model: App instance
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model)
|
||||
workflow = workflow_service.get_draft_workflow(app_model, workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Missing draft workflow configuration, please check.")
|
||||
|
||||
@ -606,7 +608,7 @@ class AppDslService:
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
|
||||
"""
|
||||
Append model config export data
|
||||
:param export_data: export data
|
||||
@ -784,7 +786,10 @@ class AppDslService:
|
||||
|
||||
@classmethod
|
||||
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
|
||||
"""Encrypt dataset_id using AES-CBC mode"""
|
||||
"""Encrypt dataset_id using AES-CBC mode or return plain text based on configuration"""
|
||||
if not dify_config.DSL_EXPORT_ENCRYPT_DATASET_ID:
|
||||
return dataset_id
|
||||
|
||||
key = cls._generate_aes_key(tenant_id)
|
||||
iv = key[:16]
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
@ -793,12 +798,34 @@ class AppDslService:
|
||||
|
||||
@classmethod
|
||||
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
|
||||
"""AES decryption"""
|
||||
"""AES decryption with fallback to plain text UUID"""
|
||||
# First, check if it's already a plain UUID (not encrypted)
|
||||
if cls._is_valid_uuid(encrypted_data):
|
||||
return encrypted_data
|
||||
|
||||
# If it's not a UUID, try to decrypt it
|
||||
try:
|
||||
key = cls._generate_aes_key(tenant_id)
|
||||
iv = key[:16]
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
|
||||
return pt.decode()
|
||||
decrypted_text = pt.decode()
|
||||
|
||||
# Validate that the decrypted result is a valid UUID
|
||||
if cls._is_valid_uuid(decrypted_text):
|
||||
return decrypted_text
|
||||
else:
|
||||
# If decrypted result is not a valid UUID, it's probably not our encrypted data
|
||||
return None
|
||||
except Exception:
|
||||
# If decryption fails completely, return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_uuid(value: str) -> bool:
|
||||
"""Check if string is a valid UUID format"""
|
||||
try:
|
||||
uuid.UUID(value)
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from openai._exceptions import RateLimitError
|
||||
|
||||
@ -55,12 +55,12 @@ class AppGenerateService:
|
||||
cls.system_rate_limiter.increment_rate_limit(app_model.tenant_id)
|
||||
|
||||
# app level rate limiter
|
||||
max_active_request = AppGenerateService._get_max_active_requests(app_model)
|
||||
max_active_request = cls._get_max_active_requests(app_model)
|
||||
rate_limit = RateLimit(app_model.id, max_active_request)
|
||||
request_id = RateLimit.gen_request_key()
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
if app_model.mode == AppMode.COMPLETION:
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
CompletionAppGenerator().generate(
|
||||
@ -69,7 +69,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
return rate_limit.generate(
|
||||
AgentChatAppGenerator.convert_to_event_stream(
|
||||
AgentChatAppGenerator().generate(
|
||||
@ -78,7 +78,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
elif app_model.mode == AppMode.CHAT:
|
||||
return rate_limit.generate(
|
||||
ChatAppGenerator.convert_to_event_stream(
|
||||
ChatAppGenerator().generate(
|
||||
@ -87,7 +87,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
@ -103,7 +103,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
@ -155,14 +155,14 @@ class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
@ -174,14 +174,14 @@ class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_loop_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
@ -214,7 +214,7 @@ class AppGenerateService:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
|
||||
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow:
|
||||
"""
|
||||
Get workflow
|
||||
:param app_model: app model
|
||||
@ -227,7 +227,7 @@ class AppGenerateService:
|
||||
# If workflow_id is specified, get the specific workflow version
|
||||
if workflow_id:
|
||||
try:
|
||||
workflow_uuid = uuid.UUID(workflow_id)
|
||||
_ = uuid.UUID(workflow_id)
|
||||
except ValueError:
|
||||
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ")
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
|
||||
|
||||
@ -6,7 +6,7 @@ from models.model import AppMode
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, TypedDict, cast
|
||||
from typing import TypedDict, cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from configs import dify_config
|
||||
@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, AppModelConfig, Site
|
||||
from models.tools import ApiToolProvider
|
||||
@ -25,6 +25,8 @@ from services.feature_service import FeatureService
|
||||
from services.tag_service import TagService
|
||||
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppService:
|
||||
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
|
||||
@ -38,15 +40,15 @@ class AppService:
|
||||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
|
||||
if args["mode"] == "workflow":
|
||||
filters.append(App.mode == AppMode.WORKFLOW.value)
|
||||
filters.append(App.mode == AppMode.WORKFLOW)
|
||||
elif args["mode"] == "completion":
|
||||
filters.append(App.mode == AppMode.COMPLETION.value)
|
||||
filters.append(App.mode == AppMode.COMPLETION)
|
||||
elif args["mode"] == "chat":
|
||||
filters.append(App.mode == AppMode.CHAT.value)
|
||||
filters.append(App.mode == AppMode.CHAT)
|
||||
elif args["mode"] == "advanced-chat":
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT)
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT.value)
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||
|
||||
if args.get("is_created_by_me", False):
|
||||
filters.append(App.created_by == user_id)
|
||||
@ -94,8 +96,8 @@ class AppService:
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
model_instance = None
|
||||
except Exception as e:
|
||||
logging.exception("Get default model instance failed, tenant_id: %s", tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Get default model instance failed, tenant_id: %s", tenant_id)
|
||||
model_instance = None
|
||||
|
||||
if model_instance:
|
||||
@ -166,9 +168,13 @@ class AppService:
|
||||
"""
|
||||
Get App
|
||||
"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get original app model config
|
||||
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
|
||||
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
|
||||
model_config = app.app_model_config
|
||||
if not model_config:
|
||||
return app
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
@ -199,11 +205,12 @@ class AppService:
|
||||
|
||||
# override tool parameters
|
||||
tool["tool_parameters"] = masked_parameter
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# override agent mode
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
if model_config:
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
|
||||
class ModifiedApp(App):
|
||||
"""
|
||||
@ -237,6 +244,7 @@ class AppService:
|
||||
:param args: request args
|
||||
:return: App instance
|
||||
"""
|
||||
assert current_user is not None
|
||||
app.name = args["name"]
|
||||
app.description = args["description"]
|
||||
app.icon_type = args["icon_type"]
|
||||
@ -257,6 +265,7 @@ class AppService:
|
||||
:param name: new name
|
||||
:return: App instance
|
||||
"""
|
||||
assert current_user is not None
|
||||
app.name = name
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = naive_utc_now()
|
||||
@ -272,6 +281,7 @@ class AppService:
|
||||
:param icon_background: new icon_background
|
||||
:return: App instance
|
||||
"""
|
||||
assert current_user is not None
|
||||
app.icon = icon
|
||||
app.icon_background = icon_background
|
||||
app.updated_by = current_user.id
|
||||
@ -289,7 +299,7 @@ class AppService:
|
||||
"""
|
||||
if enable_site == app.enable_site:
|
||||
return app
|
||||
|
||||
assert current_user is not None
|
||||
app.enable_site = enable_site
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = naive_utc_now()
|
||||
@ -306,6 +316,7 @@ class AppService:
|
||||
"""
|
||||
if enable_api == app.enable_api:
|
||||
return app
|
||||
assert current_user is not None
|
||||
|
||||
app.enable_api = enable_api
|
||||
app.updated_by = current_user.id
|
||||
@ -314,7 +325,7 @@ class AppService:
|
||||
|
||||
return app
|
||||
|
||||
def delete_app(self, app: App) -> None:
|
||||
def delete_app(self, app: App):
|
||||
"""
|
||||
Delete app
|
||||
:param app: App instance
|
||||
@ -329,7 +340,7 @@ class AppService:
|
||||
# Trigger asynchronous deletion of app and related data
|
||||
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
|
||||
|
||||
def get_app_meta(self, app_model: App) -> dict:
|
||||
def get_app_meta(self, app_model: App):
|
||||
"""
|
||||
Get app meta info
|
||||
:param app_model: app model
|
||||
@ -359,7 +370,7 @@ class AppService:
|
||||
}
|
||||
)
|
||||
else:
|
||||
app_model_config: Optional[AppModelConfig] = app_model.app_model_config
|
||||
app_model_config: AppModelConfig | None = app_model.app_model_config
|
||||
|
||||
if not app_model_config:
|
||||
return meta
|
||||
@ -382,7 +393,7 @@ class AppService:
|
||||
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
|
||||
elif provider_type == "api":
|
||||
try:
|
||||
provider: Optional[ApiToolProvider] = (
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
|
||||
)
|
||||
if provider is None:
|
||||
|
||||
@ -2,7 +2,6 @@ import io
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from werkzeug.datastructures import FileStorage
|
||||
@ -12,7 +11,7 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, AppModelConfig, Message
|
||||
from models.model import App, AppMode, Message
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
@ -30,8 +29,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class AudioService:
|
||||
@classmethod
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
@ -40,7 +39,9 @@ class AudioService:
|
||||
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
app_model_config = app_model.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
if not app_model_config.speech_to_text_dict["enabled"]:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
@ -75,18 +76,18 @@ class AudioService:
|
||||
def transcript_tts(
|
||||
cls,
|
||||
app_model: App,
|
||||
text: Optional[str] = None,
|
||||
voice: Optional[str] = None,
|
||||
end_user: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
text: str | None = None,
|
||||
voice: str | None = None,
|
||||
end_user: str | None = None,
|
||||
message_id: str | None = None,
|
||||
is_draft: bool = False,
|
||||
):
|
||||
from app import app
|
||||
|
||||
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
|
||||
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
|
||||
with app.app_context():
|
||||
if voice is None:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if is_draft:
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||
else:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import json
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
@ -8,12 +10,12 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str) -> list:
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
|
||||
.all()
|
||||
)
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
data_source_api_key_bindings = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
)
|
||||
).all()
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||
@ -70,10 +70,10 @@ class BillingService:
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_tenant_owner_or_admin(current_user):
|
||||
def is_tenant_owner_or_admin(current_user: Account):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
join: Optional[TenantAccountJoin] = (
|
||||
join: TenantAccountJoin | None = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
|
||||
@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import click
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@ -34,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ClearFreePlanTenantExpiredLogs:
|
||||
@classmethod
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
|
||||
"""
|
||||
Clean up message-related tables to avoid data redundancy.
|
||||
This method cleans up tables that have foreign key relationships with Message.
|
||||
@ -62,7 +63,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
# Query records related to expired messages
|
||||
records = (
|
||||
session.query(model)
|
||||
.filter(
|
||||
.where(
|
||||
model.message_id.in_(batch_message_ids), # type: ignore
|
||||
)
|
||||
.all()
|
||||
@ -101,7 +102,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
except Exception:
|
||||
logger.exception("Failed to save %s records", table_name)
|
||||
|
||||
session.query(model).filter(
|
||||
session.query(model).where(
|
||||
model.id.in_(record_ids), # type: ignore
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
@classmethod
|
||||
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
|
||||
with flask_app.app_context():
|
||||
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
|
||||
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
|
||||
app_ids = [app.id for app in apps]
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
@ -295,7 +296,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_app_logs = (
|
||||
session.query(WorkflowAppLog)
|
||||
.filter(
|
||||
.where(
|
||||
WorkflowAppLog.tenant_id == tenant_id,
|
||||
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
@ -321,9 +322,9 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
|
||||
|
||||
# delete workflow app logs
|
||||
session.query(WorkflowAppLog).filter(
|
||||
WorkflowAppLog.id.in_(workflow_app_log_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
@ -353,7 +354,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
def process_tenant(flask_app: Flask, tenant_id: str):
|
||||
try:
|
||||
if (
|
||||
not dify_config.BILLING_ENABLED
|
||||
@ -407,6 +408,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
tenant_count = 0
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
|
||||
@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
class CodeBasedExtensionService:
|
||||
@staticmethod
|
||||
def get_code_based_extension(module: str) -> list[dict]:
|
||||
def get_code_based_extension(module: str):
|
||||
module_extensions = code_based_extension.module_extensions(module)
|
||||
return [
|
||||
{
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy import asc, desc, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -36,12 +36,12 @@ class ConversationService:
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
user: Union[Account, EndUser] | None,
|
||||
last_id: str | None,
|
||||
limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
include_ids: Optional[Sequence[str]] = None,
|
||||
exclude_ids: Optional[Sequence[str]] = None,
|
||||
include_ids: Sequence[str] | None = None,
|
||||
exclude_ids: Sequence[str] | None = None,
|
||||
sort_by: str = "-updated_at",
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
@ -118,7 +118,7 @@ class ConversationService:
|
||||
cls,
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
user: Union[Account, EndUser] | None,
|
||||
name: str,
|
||||
auto_generate: bool,
|
||||
):
|
||||
@ -158,7 +158,7 @@ class ConversationService:
|
||||
return conversation
|
||||
|
||||
@classmethod
|
||||
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(
|
||||
@ -178,7 +178,7 @@ class ConversationService:
|
||||
return conversation
|
||||
|
||||
@classmethod
|
||||
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
try:
|
||||
logger.info(
|
||||
"Initiating conversation deletion for app_name %s, conversation_id: %s",
|
||||
@ -200,9 +200,9 @@ class ConversationService:
|
||||
cls,
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
user: Union[Account, EndUser] | None,
|
||||
limit: int,
|
||||
last_id: Optional[str],
|
||||
last_id: str | None,
|
||||
) -> InfiniteScrollPagination:
|
||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||
|
||||
@ -222,8 +222,8 @@ class ConversationService:
|
||||
# Filter for variables created after the last_id
|
||||
stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
|
||||
|
||||
# Apply limit to query
|
||||
query_stmt = stmt.limit(limit) # Get one extra to check if there are more
|
||||
# Apply limit to query: fetch one extra row to determine has_more
|
||||
query_stmt = stmt.limit(limit + 1)
|
||||
rows = session.scalars(query_stmt).all()
|
||||
|
||||
has_more = False
|
||||
@ -248,9 +248,9 @@ class ConversationService:
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
variable_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
user: Union[Account, EndUser] | None,
|
||||
new_value: Any,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Update a conversation variable's value.
|
||||
|
||||
|
||||
@ -6,10 +6,11 @@ import secrets
|
||||
import time
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Any, Literal, Optional
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -27,6 +28,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
@ -76,6 +78,8 @@ from tasks.remove_document_from_index_task import remove_document_from_index_tas
|
||||
from tasks.retry_document_indexing_task import retry_document_indexing_task
|
||||
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetService:
|
||||
@staticmethod
|
||||
@ -131,7 +135,14 @@ class DatasetService:
|
||||
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if tag_ids and len(tag_ids) > 0:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
|
||||
if tenant_id is not None:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge",
|
||||
tenant_id,
|
||||
tag_ids,
|
||||
)
|
||||
else:
|
||||
target_ids = []
|
||||
if target_ids and len(target_ids) > 0:
|
||||
query = query.where(Dataset.id.in_(target_ids))
|
||||
else:
|
||||
@ -174,16 +185,16 @@ class DatasetService:
|
||||
def create_empty_dataset(
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
description: Optional[str],
|
||||
indexing_technique: Optional[str],
|
||||
description: str | None,
|
||||
indexing_technique: str | None,
|
||||
account: Account,
|
||||
permission: Optional[str] = None,
|
||||
permission: str | None = None,
|
||||
provider: str = "vendor",
|
||||
external_knowledge_api_id: Optional[str] = None,
|
||||
external_knowledge_id: Optional[str] = None,
|
||||
embedding_model_provider: Optional[str] = None,
|
||||
embedding_model_name: Optional[str] = None,
|
||||
retrieval_model: Optional[RetrievalModel] = None,
|
||||
external_knowledge_api_id: str | None = None,
|
||||
external_knowledge_id: str | None = None,
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model_name: str | None = None,
|
||||
retrieval_model: RetrievalModel | None = None,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
||||
@ -210,7 +221,7 @@ class DatasetService:
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
# check if reranking model setting is valid
|
||||
DatasetService.check_embedding_model_setting(
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
@ -246,8 +257,8 @@ class DatasetService:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
def get_dataset(dataset_id) -> Dataset | None:
|
||||
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
@ -492,8 +503,11 @@ class DatasetService:
|
||||
data: Update data dictionary
|
||||
filtered_data: Filtered update data to modify
|
||||
"""
|
||||
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=data["embedding_model_provider"],
|
||||
@ -605,8 +619,12 @@ class DatasetService:
|
||||
data: Update data dictionary
|
||||
filtered_data: Filtered update data to modify
|
||||
"""
|
||||
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
|
||||
|
||||
model_manager = ModelManager()
|
||||
try:
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=data["embedding_model_provider"],
|
||||
@ -615,7 +633,7 @@ class DatasetService:
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
# If we can't get the embedding model, preserve existing settings
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"Failed to initialize embedding model %s/%s, preserving existing settings",
|
||||
data["embedding_model_provider"],
|
||||
data["embedding_model"],
|
||||
@ -653,19 +671,17 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def dataset_use_check(dataset_id) -> bool:
|
||||
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
|
||||
if count > 0:
|
||||
return True
|
||||
return False
|
||||
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_permission(dataset, user):
|
||||
if dataset.tenant_id != user.current_tenant_id:
|
||||
logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
if user.current_role != TenantAccountRole.OWNER:
|
||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
|
||||
logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
# For partial team permission, user needs explicit permission or be the creator
|
||||
@ -674,11 +690,11 @@ class DatasetService:
|
||||
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
|
||||
)
|
||||
if not user_permission:
|
||||
logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
|
||||
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
|
||||
@ -715,7 +731,9 @@ class DatasetService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
|
||||
def get_dataset_auto_disable_logs(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
|
||||
return {
|
||||
@ -724,14 +742,12 @@ class DatasetService:
|
||||
}
|
||||
# get recent 30 days auto disable logs
|
||||
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
||||
dataset_auto_disable_logs = (
|
||||
db.session.query(DatasetAutoDisableLog)
|
||||
.where(
|
||||
dataset_auto_disable_logs = db.session.scalars(
|
||||
select(DatasetAutoDisableLog).where(
|
||||
DatasetAutoDisableLog.dataset_id == dataset_id,
|
||||
DatasetAutoDisableLog.created_at >= start_date,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if dataset_auto_disable_logs:
|
||||
return {
|
||||
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
|
||||
@ -852,7 +868,7 @@ class DocumentService:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
|
||||
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
|
||||
if document_id:
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
@ -862,73 +878,64 @@ class DocumentService:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_id(document_id: str) -> Optional[Document]:
|
||||
def get_document_by_id(document_id: str) -> Document | None:
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.id.in_(document_ids),
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
|
||||
.all()
|
||||
)
|
||||
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
assert isinstance(current_user, Account)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.batch == batch,
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@ -965,13 +972,14 @@ class DocumentService:
|
||||
# Check if document_ids is not empty to avoid WHERE false condition
|
||||
if not document_ids or len(document_ids) == 0:
|
||||
return
|
||||
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
|
||||
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
|
||||
file_ids = [
|
||||
document.data_source_info_dict["upload_file_id"]
|
||||
for document in documents
|
||||
if document.data_source_type == "upload_file"
|
||||
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
||||
]
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
if dataset.doc_form is not None:
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
|
||||
for document in documents:
|
||||
db.session.delete(document)
|
||||
@ -979,6 +987,8 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found.")
|
||||
@ -994,7 +1004,7 @@ class DocumentService:
|
||||
if dataset.built_in_field_enabled:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata[BuiltInField.document_name.value] = name
|
||||
doc_metadata[BuiltInField.document_name] = name
|
||||
document.doc_metadata = doc_metadata
|
||||
|
||||
document.name = name
|
||||
@ -1008,6 +1018,7 @@ class DocumentService:
|
||||
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
|
||||
raise DocumentIndexingError()
|
||||
# update document to be paused
|
||||
assert current_user is not None
|
||||
document.is_paused = True
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
@ -1063,8 +1074,9 @@ class DocumentService:
|
||||
# sync document indexing
|
||||
document.indexing_status = "waiting"
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info["mode"] = "scrape"
|
||||
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
|
||||
if data_source_info:
|
||||
data_source_info["mode"] = "scrape"
|
||||
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -1087,12 +1099,15 @@ class DocumentService:
|
||||
dataset: Dataset,
|
||||
knowledge_config: KnowledgeConfig,
|
||||
account: Account | Any,
|
||||
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
dataset_process_rule: DatasetProcessRule | None = None,
|
||||
created_from: str = "web",
|
||||
):
|
||||
) -> tuple[list[Document], str]:
|
||||
# check doc_form
|
||||
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
|
||||
# check document limit
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
@ -1149,7 +1164,7 @@ class DocumentService:
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
@ -1190,11 +1205,11 @@ class DocumentService:
|
||||
created_by=account.id,
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"Invalid process rule mode: %s, can not find dataset process rule",
|
||||
process_rule.mode,
|
||||
)
|
||||
return
|
||||
return [], ""
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.commit()
|
||||
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
|
||||
@ -1429,6 +1444,8 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_documents_count():
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
documents_count = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
@ -1446,9 +1463,11 @@ class DocumentService:
|
||||
dataset: Dataset,
|
||||
document_data: KnowledgeConfig,
|
||||
account: Account,
|
||||
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
dataset_process_rule: DatasetProcessRule | None = None,
|
||||
created_from: str = "web",
|
||||
):
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document = DocumentService.get_document(dataset.id, document_data.original_document_id)
|
||||
if document is None:
|
||||
@ -1508,7 +1527,7 @@ class DocumentService:
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
db.and_(
|
||||
sa.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
@ -1569,6 +1588,9 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
@ -1612,7 +1634,7 @@ class DocumentService:
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=2,
|
||||
top_k=4,
|
||||
score_threshold_enabled=False,
|
||||
)
|
||||
# save dataset
|
||||
@ -1882,7 +1904,7 @@ class DocumentService:
|
||||
task_func.delay(*task_args)
|
||||
except Exception as e:
|
||||
# Log the error but do not rollback the transaction
|
||||
logging.exception("Error executing async task for document %s", update_info["document"].id)
|
||||
logger.exception("Error executing async task for document %s", update_info["document"].id)
|
||||
# don't raise the error immediately, but capture it for later
|
||||
propagation_error = e
|
||||
try:
|
||||
@ -1893,7 +1915,7 @@ class DocumentService:
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
except Exception as e:
|
||||
# Log the error but do not rollback the transaction
|
||||
logging.exception("Error setting cache for document %s", update_info["document"].id)
|
||||
logger.exception("Error setting cache for document %s", update_info["document"].id)
|
||||
# Raise any propagation error after all updates
|
||||
if propagation_error:
|
||||
raise propagation_error
|
||||
@ -2008,6 +2030,9 @@ class SegmentService:
|
||||
|
||||
@classmethod
|
||||
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
content = args["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
@ -2059,7 +2084,7 @@ class SegmentService:
|
||||
try:
|
||||
VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form)
|
||||
except Exception as e:
|
||||
logging.exception("create segment index failed")
|
||||
logger.exception("create segment index failed")
|
||||
segment_document.enabled = False
|
||||
segment_document.disabled_at = naive_utc_now()
|
||||
segment_document.status = "error"
|
||||
@ -2070,6 +2095,9 @@ class SegmentService:
|
||||
|
||||
@classmethod
|
||||
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
|
||||
increment_word_count = 0
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -2142,7 +2170,7 @@ class SegmentService:
|
||||
# save vector index
|
||||
VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form)
|
||||
except Exception as e:
|
||||
logging.exception("create segment index failed")
|
||||
logger.exception("create segment index failed")
|
||||
for segment_document in segment_data_list:
|
||||
segment_document.enabled = False
|
||||
segment_document.disabled_at = naive_utc_now()
|
||||
@ -2153,6 +2181,9 @@ class SegmentService:
|
||||
|
||||
@classmethod
|
||||
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
@ -2314,7 +2345,7 @@ class SegmentService:
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("update segment index failed")
|
||||
logger.exception("update segment index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
@ -2334,7 +2365,22 @@ class SegmentService:
|
||||
if segment.enabled:
|
||||
# send delete segment index task
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
|
||||
|
||||
# Get child chunk IDs before parent segment is deleted
|
||||
child_node_ids = []
|
||||
if segment.index_node_id:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.where(
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||
|
||||
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
|
||||
|
||||
db.session.delete(segment)
|
||||
# update document word count
|
||||
assert document.word_count is not None
|
||||
@ -2344,9 +2390,14 @@ class SegmentService:
|
||||
|
||||
@classmethod
|
||||
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
|
||||
segments = (
|
||||
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
|
||||
.filter(
|
||||
assert current_user is not None
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
segments_info = (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
|
||||
.where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
@ -2355,16 +2406,36 @@ class SegmentService:
|
||||
.all()
|
||||
)
|
||||
|
||||
if not segments:
|
||||
if not segments_info:
|
||||
return
|
||||
|
||||
index_node_ids = [seg.index_node_id for seg in segments]
|
||||
total_words = sum(seg.word_count for seg in segments)
|
||||
index_node_ids = [info[0] for info in segments_info]
|
||||
segment_db_ids = [info[1] for info in segments_info]
|
||||
total_words = sum(info[2] for info in segments_info if info[2] is not None)
|
||||
|
||||
document.word_count -= total_words
|
||||
# Get child chunk IDs before parent segments are deleted
|
||||
child_node_ids = []
|
||||
if index_node_ids:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.where(
|
||||
ChildChunk.segment_id.in_(segment_db_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||
|
||||
# Start async cleanup with both parent and child node IDs
|
||||
if index_node_ids or child_node_ids:
|
||||
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
|
||||
|
||||
document.word_count = (
|
||||
document.word_count - total_words if document.word_count and document.word_count > total_words else 0
|
||||
)
|
||||
db.session.add(document)
|
||||
|
||||
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
|
||||
# Delete database records
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
|
||||
db.session.commit()
|
||||
|
||||
@ -2372,20 +2443,20 @@ class SegmentService:
|
||||
def update_segments_status(
|
||||
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
|
||||
):
|
||||
assert current_user is not None
|
||||
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
if action == "enable":
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
@ -2403,16 +2474,14 @@ class SegmentService:
|
||||
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
elif action == "disable":
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
@ -2434,20 +2503,12 @@ class SegmentService:
|
||||
def create_child_chunk(
|
||||
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
|
||||
) -> ChildChunk:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
lock_name = f"add_child_lock_{segment.id}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
index_node_id = str(uuid.uuid4())
|
||||
index_node_hash = helper.generate_text_hash(content)
|
||||
child_chunk_count = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
ChildChunk.document_id == document.id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
max_position = (
|
||||
db.session.query(func.max(ChildChunk.position))
|
||||
.where(
|
||||
@ -2476,7 +2537,7 @@ class SegmentService:
|
||||
try:
|
||||
VectorService.create_child_chunk_vector(child_chunk, dataset)
|
||||
except Exception as e:
|
||||
logging.exception("create child chunk index failed")
|
||||
logger.exception("create child chunk index failed")
|
||||
db.session.rollback()
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
db.session.commit()
|
||||
@ -2491,15 +2552,14 @@ class SegmentService:
|
||||
document: Document,
|
||||
dataset: Dataset,
|
||||
) -> list[ChildChunk]:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
assert isinstance(current_user, Account)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
ChildChunk.document_id == document.id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
|
||||
|
||||
new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
|
||||
@ -2551,7 +2611,7 @@ class SegmentService:
|
||||
VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
logging.exception("update child chunk index failed")
|
||||
logger.exception("update child chunk index failed")
|
||||
db.session.rollback()
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position)
|
||||
@ -2565,6 +2625,8 @@ class SegmentService:
|
||||
document: Document,
|
||||
dataset: Dataset,
|
||||
) -> ChildChunk:
|
||||
assert current_user is not None
|
||||
|
||||
try:
|
||||
child_chunk.content = content
|
||||
child_chunk.word_count = len(content)
|
||||
@ -2575,7 +2637,7 @@ class SegmentService:
|
||||
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
logging.exception("update child chunk index failed")
|
||||
logger.exception("update child chunk index failed")
|
||||
db.session.rollback()
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return child_chunk
|
||||
@ -2586,15 +2648,17 @@ class SegmentService:
|
||||
try:
|
||||
VectorService.delete_child_chunk_vector(child_chunk, dataset)
|
||||
except Exception as e:
|
||||
logging.exception("delete child chunk index failed")
|
||||
logger.exception("delete child chunk index failed")
|
||||
db.session.rollback()
|
||||
raise ChildChunkDeleteIndexError(str(e))
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_child_chunks(
|
||||
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
|
||||
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: str | None = None
|
||||
):
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
query = (
|
||||
select(ChildChunk)
|
||||
.filter_by(
|
||||
@ -2610,7 +2674,7 @@ class SegmentService:
|
||||
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
@classmethod
|
||||
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
|
||||
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
|
||||
"""Get a child chunk by its ID."""
|
||||
result = (
|
||||
db.session.query(ChildChunk)
|
||||
@ -2647,57 +2711,7 @@ class SegmentService:
|
||||
return paginated_segments.items, paginated_segments.total
|
||||
|
||||
@classmethod
|
||||
def update_segment_by_id(
|
||||
cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
|
||||
) -> tuple[DocumentSegment, Document]:
|
||||
"""Update a segment by its ID with validation and checks."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
# check embedding model setting if high quality
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=user_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
# check segment
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate and update segment
|
||||
cls.segment_create_args_validate(segment_data, document)
|
||||
updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)
|
||||
|
||||
return updated_segment, document
|
||||
|
||||
@classmethod
|
||||
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
|
||||
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
|
||||
"""Get a segment by its ID."""
|
||||
result = (
|
||||
db.session.query(DocumentSegment)
|
||||
@ -2755,19 +2769,13 @@ class DatasetCollectionBindingService:
|
||||
class DatasetPermissionService:
|
||||
@classmethod
|
||||
def get_dataset_partial_member_list(cls, dataset_id):
|
||||
user_list_query = (
|
||||
db.session.query(
|
||||
user_list_query = db.session.scalars(
|
||||
select(
|
||||
DatasetPermission.account_id,
|
||||
)
|
||||
.where(DatasetPermission.dataset_id == dataset_id)
|
||||
.all()
|
||||
)
|
||||
).where(DatasetPermission.dataset_id == dataset_id)
|
||||
).all()
|
||||
|
||||
user_list = []
|
||||
for user in user_list_query:
|
||||
user_list.append(user.account_id)
|
||||
|
||||
return user_list
|
||||
return user_list_query
|
||||
|
||||
@classmethod
|
||||
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
|
||||
|
||||
@ -3,18 +3,30 @@ import os
|
||||
import requests
|
||||
|
||||
|
||||
class EnterpriseRequest:
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
|
||||
|
||||
class BaseRequest:
|
||||
proxies = {
|
||||
"http": "",
|
||||
"https": "",
|
||||
}
|
||||
base_url = ""
|
||||
secret_key = ""
|
||||
secret_key_header = ""
|
||||
|
||||
@classmethod
|
||||
def send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
|
||||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies)
|
||||
return response.json()
|
||||
|
||||
|
||||
class EnterpriseRequest(BaseRequest):
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
|
||||
secret_key_header = "Enterprise-Api-Secret-Key"
|
||||
|
||||
|
||||
class EnterprisePluginManagerRequest(BaseRequest):
|
||||
base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL")
|
||||
secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY")
|
||||
secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key"
|
||||
|
||||
57
api/services/enterprise/plugin_manager_service.py
Normal file
57
api/services/enterprise/plugin_manager_service.py
Normal file
@ -0,0 +1,57 @@
|
||||
import enum
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.enterprise.base import EnterprisePluginManagerRequest
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginCredentialType(enum.IntEnum):
|
||||
MODEL = enum.auto()
|
||||
TOOL = enum.auto()
|
||||
|
||||
def to_number(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class CheckCredentialPolicyComplianceRequest(BaseModel):
|
||||
dify_credential_id: str
|
||||
provider: str
|
||||
credential_type: PluginCredentialType
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
data = super().model_dump(**kwargs)
|
||||
data["credential_type"] = self.credential_type.to_number()
|
||||
return data
|
||||
|
||||
|
||||
class CredentialPolicyViolationError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class PluginManagerService:
|
||||
@classmethod
|
||||
def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest):
|
||||
try:
|
||||
ret = EnterprisePluginManagerRequest.send_request(
|
||||
"POST", "/check-credential-policy-compliance", json=body.model_dump()
|
||||
)
|
||||
if not isinstance(ret, dict) or "result" not in ret:
|
||||
raise ValueError("Invalid response format from plugin manager API")
|
||||
except Exception as e:
|
||||
raise CredentialPolicyViolationError(
|
||||
f"error occurred while checking credential policy compliance: {e}"
|
||||
) from e
|
||||
|
||||
if not ret.get("result", False):
|
||||
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")
|
||||
|
||||
logger.debug(
|
||||
"Credential policy compliance checked for %s with credential %s, result: %s",
|
||||
body.provider,
|
||||
body.dify_credential_id,
|
||||
ret.get("result", False),
|
||||
)
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel):
|
||||
|
||||
class Authorization(BaseModel):
|
||||
type: Literal["no-auth", "api-key"]
|
||||
config: Optional[AuthorizationConfig] = None
|
||||
config: AuthorizationConfig | None = None
|
||||
|
||||
|
||||
class ProcessStatusSetting(BaseModel):
|
||||
@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel):
|
||||
class ExternalKnowledgeApiSetting(BaseModel):
|
||||
url: str
|
||||
request_method: str
|
||||
headers: Optional[dict] = None
|
||||
params: Optional[dict] = None
|
||||
headers: dict | None = None
|
||||
params: dict | None = None
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -11,14 +11,14 @@ class ParentMode(StrEnum):
|
||||
|
||||
class NotionIcon(BaseModel):
|
||||
type: str
|
||||
url: Optional[str] = None
|
||||
emoji: Optional[str] = None
|
||||
url: str | None = None
|
||||
emoji: str | None = None
|
||||
|
||||
|
||||
class NotionPage(BaseModel):
|
||||
page_id: str
|
||||
page_name: str
|
||||
page_icon: Optional[NotionIcon] = None
|
||||
page_icon: NotionIcon | None = None
|
||||
type: str
|
||||
|
||||
|
||||
@ -40,9 +40,9 @@ class FileInfo(BaseModel):
|
||||
|
||||
class InfoList(BaseModel):
|
||||
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
|
||||
notion_info_list: Optional[list[NotionInfo]] = None
|
||||
file_info_list: Optional[FileInfo] = None
|
||||
website_info_list: Optional[WebsiteInfo] = None
|
||||
notion_info_list: list[NotionInfo] | None = None
|
||||
file_info_list: FileInfo | None = None
|
||||
website_info_list: WebsiteInfo | None = None
|
||||
|
||||
|
||||
class DataSource(BaseModel):
|
||||
@ -61,20 +61,20 @@ class Segmentation(BaseModel):
|
||||
|
||||
|
||||
class Rule(BaseModel):
|
||||
pre_processing_rules: Optional[list[PreProcessingRule]] = None
|
||||
segmentation: Optional[Segmentation] = None
|
||||
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
|
||||
subchunk_segmentation: Optional[Segmentation] = None
|
||||
pre_processing_rules: list[PreProcessingRule] | None = None
|
||||
segmentation: Segmentation | None = None
|
||||
parent_mode: Literal["full-doc", "paragraph"] | None = None
|
||||
subchunk_segmentation: Segmentation | None = None
|
||||
|
||||
|
||||
class ProcessRule(BaseModel):
|
||||
mode: Literal["automatic", "custom", "hierarchical"]
|
||||
rules: Optional[Rule] = None
|
||||
rules: Rule | None = None
|
||||
|
||||
|
||||
class RerankingModel(BaseModel):
|
||||
reranking_provider_name: Optional[str] = None
|
||||
reranking_model_name: Optional[str] = None
|
||||
reranking_provider_name: str | None = None
|
||||
reranking_model_name: str | None = None
|
||||
|
||||
|
||||
class WeightVectorSetting(BaseModel):
|
||||
@ -88,20 +88,20 @@ class WeightKeywordSetting(BaseModel):
|
||||
|
||||
|
||||
class WeightModel(BaseModel):
|
||||
weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None
|
||||
vector_setting: Optional[WeightVectorSetting] = None
|
||||
keyword_setting: Optional[WeightKeywordSetting] = None
|
||||
weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None
|
||||
vector_setting: WeightVectorSetting | None = None
|
||||
keyword_setting: WeightKeywordSetting | None = None
|
||||
|
||||
|
||||
class RetrievalModel(BaseModel):
|
||||
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
|
||||
reranking_enable: bool
|
||||
reranking_model: Optional[RerankingModel] = None
|
||||
reranking_mode: Optional[str] = None
|
||||
reranking_model: RerankingModel | None = None
|
||||
reranking_mode: str | None = None
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
score_threshold: Optional[float] = None
|
||||
weights: Optional[WeightModel] = None
|
||||
score_threshold: float | None = None
|
||||
weights: WeightModel | None = None
|
||||
|
||||
|
||||
class MetaDataConfig(BaseModel):
|
||||
@ -110,29 +110,29 @@ class MetaDataConfig(BaseModel):
|
||||
|
||||
|
||||
class KnowledgeConfig(BaseModel):
|
||||
original_document_id: Optional[str] = None
|
||||
original_document_id: str | None = None
|
||||
duplicate: bool = True
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
data_source: Optional[DataSource] = None
|
||||
process_rule: Optional[ProcessRule] = None
|
||||
retrieval_model: Optional[RetrievalModel] = None
|
||||
data_source: DataSource | None = None
|
||||
process_rule: ProcessRule | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
doc_form: str = "text_model"
|
||||
doc_language: str = "English"
|
||||
embedding_model: Optional[str] = None
|
||||
embedding_model_provider: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class SegmentUpdateArgs(BaseModel):
|
||||
content: Optional[str] = None
|
||||
answer: Optional[str] = None
|
||||
keywords: Optional[list[str]] = None
|
||||
content: str | None = None
|
||||
answer: str | None = None
|
||||
keywords: list[str] | None = None
|
||||
regenerate_child_chunks: bool = False
|
||||
enabled: Optional[bool] = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ChildChunkUpdateArgs(BaseModel):
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
content: str
|
||||
|
||||
|
||||
@ -143,13 +143,13 @@ class MetadataArgs(BaseModel):
|
||||
|
||||
class MetadataUpdateArgs(BaseModel):
|
||||
name: str
|
||||
value: Optional[str | int | float] = None
|
||||
value: str | int | float | None = None
|
||||
|
||||
|
||||
class MetadataDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
value: Optional[str | int | float] = None
|
||||
value: str | int | float | None = None
|
||||
|
||||
|
||||
class DocumentMetadataOperation(BaseModel):
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@ -8,7 +7,13 @@ from core.entities.model_entities import (
|
||||
ModelWithProviderEntity,
|
||||
ProviderModelWithStatusEntity,
|
||||
)
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
|
||||
from core.entities.provider_entities import (
|
||||
CredentialConfiguration,
|
||||
CustomModelConfiguration,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
UnaddedModelConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
@ -36,6 +41,11 @@ class CustomConfigurationResponse(BaseModel):
|
||||
"""
|
||||
|
||||
status: CustomConfigurationStatus
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_credentials: list[CredentialConfiguration] | None = None
|
||||
custom_models: list[CustomModelConfiguration] | None = None
|
||||
can_added_models: list[UnaddedModelConfiguration] | None = None
|
||||
|
||||
|
||||
class SystemConfigurationResponse(BaseModel):
|
||||
@ -44,7 +54,7 @@ class SystemConfigurationResponse(BaseModel):
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
current_quota_type: ProviderQuotaType | None = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
|
||||
|
||||
@ -56,15 +66,15 @@ class ProviderResponse(BaseModel):
|
||||
tenant_id: str
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
description: I18nObject | None = None
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
background: str | None = None
|
||||
help: ProviderHelpEntity | None = None
|
||||
supported_model_types: list[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
provider_credential_schema: ProviderCredentialSchema | None = None
|
||||
model_credential_schema: ModelCredentialSchema | None = None
|
||||
preferred_provider_type: ProviderType
|
||||
custom_configuration: CustomConfigurationResponse
|
||||
system_configuration: SystemConfigurationResponse
|
||||
@ -72,7 +82,7 @@ class ProviderResponse(BaseModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@ -97,12 +107,12 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
tenant_id: str
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
status: CustomConfigurationStatus
|
||||
models: list[ProviderModelWithStatusEntity]
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@ -126,7 +136,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@ -163,7 +173,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
|
||||
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
|
||||
dump_model = model.model_dump()
|
||||
dump_model["provider"]["tenant_id"] = tenant_id
|
||||
super().__init__(**dump_model)
|
||||
|
||||
@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError
|
||||
|
||||
class AppModelConfigBrokenError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ProviderNotFoundError(BaseServiceError):
|
||||
pass
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BaseServiceError(ValueError):
|
||||
def __init__(self, description: Optional[str] = None):
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InvokeError(Exception):
|
||||
"""Base class for all LLM exceptions."""
|
||||
|
||||
description: Optional[str] = None
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: Optional[str] = None) -> None:
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.workflow.nodes.http_request.exc import InvalidHttpMethodError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import (
|
||||
@ -88,7 +89,7 @@ class ExternalDatasetService:
|
||||
raise ValueError(f"invalid endpoint: {endpoint}")
|
||||
try:
|
||||
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
|
||||
if response.status_code == 502:
|
||||
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
|
||||
@ -99,7 +100,7 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
|
||||
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
@ -108,13 +109,14 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
|
||||
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
|
||||
external_knowledge_api: ExternalKnowledgeApis | None = (
|
||||
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
|
||||
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key")
|
||||
settings = args.get("settings")
|
||||
if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
|
||||
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")
|
||||
|
||||
external_knowledge_api.name = args.get("name")
|
||||
external_knowledge_api.description = args.get("description", "")
|
||||
@ -149,7 +151,7 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
||||
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
|
||||
external_knowledge_binding: ExternalKnowledgeBindings | None = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
@ -179,19 +181,29 @@ class ExternalDatasetService:
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"url": settings.url,
|
||||
"headers": settings.headers,
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
response: httpx.Response = getattr(ssrf_proxy, settings.request_method)(
|
||||
data=json.dumps(settings.params), files=files, **kwargs
|
||||
)
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = settings.request_method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {settings.request_method}")
|
||||
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](data=json.dumps(settings.params), files=files, **kwargs)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
|
||||
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
|
||||
authorization = deepcopy(authorization)
|
||||
if headers:
|
||||
headers = deepcopy(headers)
|
||||
@ -218,7 +230,7 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
|
||||
return ExternalKnowledgeApiSetting.parse_obj(settings)
|
||||
return ExternalKnowledgeApiSetting.model_validate(settings)
|
||||
|
||||
@staticmethod
|
||||
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
|
||||
@ -265,8 +277,8 @@ class ExternalDatasetService:
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_parameters: dict,
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
) -> list:
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
):
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
)
|
||||
|
||||
@ -134,6 +134,10 @@ class KnowledgeRateLimitModel(BaseModel):
|
||||
subscription_plan: str = ""
|
||||
|
||||
|
||||
class PluginManagerModel(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
|
||||
class SystemFeatureModel(BaseModel):
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ""
|
||||
@ -150,6 +154,7 @@ class SystemFeatureModel(BaseModel):
|
||||
webapp_auth: WebAppAuthModel = WebAppAuthModel()
|
||||
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
|
||||
enable_change_email: bool = True
|
||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||
|
||||
|
||||
class FeatureService:
|
||||
@ -188,6 +193,7 @@ class FeatureService:
|
||||
system_features.branding.enabled = True
|
||||
system_features.webapp_auth.enabled = True
|
||||
system_features.enable_change_email = False
|
||||
system_features.plugin_manager.enabled = True
|
||||
cls._fulfill_params_from_enterprise(system_features)
|
||||
|
||||
if dify_config.MARKETPLACE_ENABLED:
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from flask_login import current_user
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@ -19,6 +18,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_tenant_id
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
@ -35,7 +35,7 @@ class FileService:
|
||||
filename: str,
|
||||
content: bytes,
|
||||
mimetype: str,
|
||||
user: Union[Account, EndUser, Any],
|
||||
user: Union[Account, EndUser],
|
||||
source: Literal["datasets"] | None = None,
|
||||
source_url: str = "",
|
||||
) -> UploadFile:
|
||||
@ -111,6 +111,9 @@ class FileService:
|
||||
|
||||
@staticmethod
|
||||
def upload_text(text: str, text_name: str) -> UploadFile:
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
if len(text_name) > 200:
|
||||
text_name = text_name[:200]
|
||||
# user uuid as file name
|
||||
|
||||
@ -12,11 +12,13 @@ from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
@ -31,7 +33,7 @@ class HitTestingService:
|
||||
retrieval_model: Any, # FIXME drop this any
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
) -> dict:
|
||||
):
|
||||
start = time.perf_counter()
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
@ -64,7 +66,7 @@ class HitTestingService:
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 2),
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
@ -77,7 +79,7 @@ class HitTestingService:
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug("Hit testing retrieve in %s seconds", end - start)
|
||||
logger.debug("Hit testing retrieve in %s seconds", end - start)
|
||||
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
|
||||
@ -96,7 +98,7 @@ class HitTestingService:
|
||||
account: Account,
|
||||
external_retrieval_model: dict,
|
||||
metadata_filtering_conditions: dict,
|
||||
) -> dict:
|
||||
):
|
||||
if dataset.provider != "external":
|
||||
return {
|
||||
"query": {"content": query},
|
||||
@ -113,7 +115,7 @@ class HitTestingService:
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug("External knowledge hit testing retrieve in %s seconds", end - start)
|
||||
logger.debug("External knowledge hit testing retrieve in %s seconds", end - start)
|
||||
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -29,9 +29,9 @@ class MessageService:
|
||||
def pagination_by_first_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
user: Union[Account, EndUser] | None,
|
||||
conversation_id: str,
|
||||
first_id: Optional[str],
|
||||
first_id: str | None,
|
||||
limit: int,
|
||||
order: str = "asc",
|
||||
) -> InfiniteScrollPagination:
|
||||
@ -91,11 +91,11 @@ class MessageService:
|
||||
def pagination_by_last_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
user: Union[Account, EndUser] | None,
|
||||
last_id: str | None,
|
||||
limit: int,
|
||||
conversation_id: Optional[str] = None,
|
||||
include_ids: Optional[list] = None,
|
||||
conversation_id: str | None = None,
|
||||
include_ids: list | None = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
@ -112,7 +112,9 @@ class MessageService:
|
||||
base_query = base_query.where(Message.conversation_id == conversation.id)
|
||||
|
||||
# Check if include_ids is not None and not empty to avoid WHERE false condition
|
||||
if include_ids is not None and len(include_ids) > 0:
|
||||
if include_ids is not None:
|
||||
if len(include_ids) == 0:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
base_query = base_query.where(Message.id.in_(include_ids))
|
||||
|
||||
if last_id:
|
||||
@ -143,9 +145,9 @@ class MessageService:
|
||||
*,
|
||||
app_model: App,
|
||||
message_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
rating: Optional[str],
|
||||
content: Optional[str],
|
||||
user: Union[Account, EndUser] | None,
|
||||
rating: str | None,
|
||||
content: str | None,
|
||||
):
|
||||
if not user:
|
||||
raise ValueError("user cannot be None")
|
||||
@ -194,7 +196,7 @@ class MessageService:
|
||||
return [record.to_dict() for record in feedbacks]
|
||||
|
||||
@classmethod
|
||||
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
@ -214,7 +216,7 @@ class MessageService:
|
||||
|
||||
@classmethod
|
||||
def get_suggested_questions_after_answer(
|
||||
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
|
||||
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
|
||||
) -> list[Message]:
|
||||
if not user:
|
||||
raise ValueError("user cannot be None")
|
||||
@ -227,7 +229,7 @@ class MessageService:
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_service = WorkflowService()
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import copy
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
@ -15,6 +14,8 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataOperationData,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetadataService:
|
||||
@staticmethod
|
||||
@ -90,7 +91,7 @@ class MetadataService:
|
||||
db.session.commit()
|
||||
return metadata # type: ignore
|
||||
except Exception:
|
||||
logging.exception("Update metadata name failed")
|
||||
logger.exception("Update metadata name failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@ -122,18 +123,18 @@ class MetadataService:
|
||||
db.session.commit()
|
||||
return metadata
|
||||
except Exception:
|
||||
logging.exception("Delete metadata failed")
|
||||
logger.exception("Delete metadata failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def get_built_in_fields():
|
||||
return [
|
||||
{"name": BuiltInField.document_name.value, "type": "string"},
|
||||
{"name": BuiltInField.uploader.value, "type": "string"},
|
||||
{"name": BuiltInField.upload_date.value, "type": "time"},
|
||||
{"name": BuiltInField.last_update_date.value, "type": "time"},
|
||||
{"name": BuiltInField.source.value, "type": "string"},
|
||||
{"name": BuiltInField.document_name, "type": "string"},
|
||||
{"name": BuiltInField.uploader, "type": "string"},
|
||||
{"name": BuiltInField.upload_date, "type": "time"},
|
||||
{"name": BuiltInField.last_update_date, "type": "time"},
|
||||
{"name": BuiltInField.source, "type": "string"},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@ -151,17 +152,17 @@ class MetadataService:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata[BuiltInField.document_name.value] = document.name
|
||||
doc_metadata[BuiltInField.uploader.value] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
doc_metadata[BuiltInField.document_name] = document.name
|
||||
doc_metadata[BuiltInField.uploader] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Enable built-in field failed")
|
||||
logger.exception("Enable built-in field failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@ -181,18 +182,18 @@ class MetadataService:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(BuiltInField.document_name.value, None)
|
||||
doc_metadata.pop(BuiltInField.uploader.value, None)
|
||||
doc_metadata.pop(BuiltInField.upload_date.value, None)
|
||||
doc_metadata.pop(BuiltInField.last_update_date.value, None)
|
||||
doc_metadata.pop(BuiltInField.source.value, None)
|
||||
doc_metadata.pop(BuiltInField.document_name, None)
|
||||
doc_metadata.pop(BuiltInField.uploader, None)
|
||||
doc_metadata.pop(BuiltInField.upload_date, None)
|
||||
doc_metadata.pop(BuiltInField.last_update_date, None)
|
||||
doc_metadata.pop(BuiltInField.source, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
document_ids.append(document.id)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Disable built-in field failed")
|
||||
logger.exception("Disable built-in field failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@ -209,11 +210,11 @@ class MetadataService:
|
||||
for metadata_value in operation.metadata_list:
|
||||
doc_metadata[metadata_value.name] = metadata_value.value
|
||||
if dataset.built_in_field_enabled:
|
||||
doc_metadata[BuiltInField.document_name.value] = document.name
|
||||
doc_metadata[BuiltInField.uploader.value] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
doc_metadata[BuiltInField.document_name] = document.name
|
||||
doc_metadata[BuiltInField.uploader] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
@ -230,12 +231,12 @@ class MetadataService:
|
||||
db.session.add(dataset_metadata_binding)
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Update documents metadata failed")
|
||||
logger.exception("Update documents metadata failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
|
||||
def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None):
|
||||
if dataset_id:
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
if redis_client.get(lock_key):
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_configuration import ProviderConfiguration
|
||||
@ -17,16 +19,16 @@ from core.model_runtime.model_providers.model_provider_factory import ModelProvi
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import LoadBalancingModelConfig
|
||||
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelLoadBalancingService:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
enable model load balancing.
|
||||
|
||||
@ -47,7 +49,7 @@ class ModelLoadBalancingService:
|
||||
# Enable model load balancing
|
||||
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
disable model load balancing.
|
||||
|
||||
@ -69,7 +71,7 @@ class ModelLoadBalancingService:
|
||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def get_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
|
||||
) -> tuple[bool, list[dict]]:
|
||||
"""
|
||||
Get load balancing configurations.
|
||||
@ -100,6 +102,11 @@ class ModelLoadBalancingService:
|
||||
if provider_model_setting and provider_model_setting.load_balancing_enabled:
|
||||
is_load_balancing_enabled = True
|
||||
|
||||
if config_from == "predefined-model":
|
||||
credential_source_type = "provider"
|
||||
else:
|
||||
credential_source_type = "custom_model"
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
@ -108,6 +115,10 @@ class ModelLoadBalancingService:
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||
),
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
.all()
|
||||
@ -154,7 +165,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
try:
|
||||
if load_balancing_config.encrypted_config:
|
||||
credentials = json.loads(load_balancing_config.encrypted_config)
|
||||
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
|
||||
else:
|
||||
credentials = {}
|
||||
except JSONDecodeError:
|
||||
@ -169,9 +180,13 @@ class ModelLoadBalancingService:
|
||||
for variable in credential_secret_variables:
|
||||
if variable in credentials:
|
||||
try:
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
|
||||
)
|
||||
token_value = credentials.get(variable)
|
||||
if isinstance(token_value, str):
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
token_value,
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@ -185,6 +200,7 @@ class ModelLoadBalancingService:
|
||||
"id": load_balancing_config.id,
|
||||
"name": load_balancing_config.name,
|
||||
"credentials": credentials,
|
||||
"credential_id": load_balancing_config.credential_id,
|
||||
"enabled": load_balancing_config.enabled,
|
||||
"in_cooldown": in_cooldown,
|
||||
"ttl": ttl,
|
||||
@ -195,7 +211,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
def get_load_balancing_config(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
|
||||
) -> Optional[dict]:
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get load balancing configuration.
|
||||
:param tenant_id: workspace id
|
||||
@ -280,8 +296,8 @@ class ModelLoadBalancingService:
|
||||
return inherit_config
|
||||
|
||||
def update_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
|
||||
) -> None:
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
|
||||
):
|
||||
"""
|
||||
Update load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@ -289,6 +305,7 @@ class ModelLoadBalancingService:
|
||||
:param model: model name
|
||||
:param model_type: model type
|
||||
:param configs: load balancing configs
|
||||
:param config_from: predefined-model or custom-model
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
@ -305,16 +322,14 @@ class ModelLoadBalancingService:
|
||||
if not isinstance(configs, list):
|
||||
raise ValueError("Invalid load balancing configs")
|
||||
|
||||
current_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.where(
|
||||
current_load_balancing_configs = db.session.scalars(
|
||||
select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# id as key, config as value
|
||||
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
|
||||
@ -327,8 +342,38 @@ class ModelLoadBalancingService:
|
||||
config_id = config.get("id")
|
||||
name = config.get("name")
|
||||
credentials = config.get("credentials")
|
||||
credential_id = config.get("credential_id")
|
||||
enabled = config.get("enabled")
|
||||
|
||||
credential_record: ProviderCredential | ProviderModelCredential | None = None
|
||||
|
||||
if credential_id:
|
||||
if config_from == "predefined-model":
|
||||
credential_record = (
|
||||
db.session.query(ProviderCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
credential_record = (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.filter_by(
|
||||
id=credential_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not credential_record:
|
||||
raise ValueError(f"Provider credential with id {credential_id} not found")
|
||||
name = credential_record.credential_name
|
||||
|
||||
if not name:
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
@ -346,11 +391,6 @@ class ModelLoadBalancingService:
|
||||
|
||||
load_balancing_config = current_load_balancing_configs_dict[config_id]
|
||||
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
|
||||
raise ValueError(f"Load balancing config name {name} already exists")
|
||||
|
||||
if credentials:
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
@ -380,36 +420,45 @@ class ModelLoadBalancingService:
|
||||
if name == "__inherit__":
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.name == name:
|
||||
raise ValueError(f"Load balancing config name {name} already exists")
|
||||
if credential_id:
|
||||
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
|
||||
assert credential_record is not None
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=credential_record.credential_name,
|
||||
encrypted_config=credential_record.encrypted_config,
|
||||
credential_id=credential_id,
|
||||
credential_source_type=credential_source,
|
||||
)
|
||||
else:
|
||||
if not credentials:
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
if not credentials:
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# create load balancing config
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
)
|
||||
# create load balancing config
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
)
|
||||
|
||||
db.session.add(load_balancing_model_config)
|
||||
db.session.commit()
|
||||
@ -429,8 +478,8 @@ class ModelLoadBalancingService:
|
||||
model: str,
|
||||
model_type: str,
|
||||
credentials: dict,
|
||||
config_id: Optional[str] = None,
|
||||
) -> None:
|
||||
config_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Validate load balancing credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -487,9 +536,9 @@ class ModelLoadBalancingService:
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
|
||||
load_balancing_model_config: LoadBalancingModelConfig | None = None,
|
||||
validate: bool = True,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -557,7 +606,7 @@ class ModelLoadBalancingService:
|
||||
else:
|
||||
raise ValueError("No credential schema found")
|
||||
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
|
||||
"""
|
||||
Clear credentials cache.
|
||||
:param tenant_id: workspace id
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
@ -16,6 +15,7 @@ from services.entities.model_provider_entities import (
|
||||
SimpleProviderEntityResponse,
|
||||
SystemConfigurationResponse,
|
||||
)
|
||||
from services.errors.app_model_config import ProviderNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -25,10 +25,33 @@ class ModelProviderService:
|
||||
Model Provider Service
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
|
||||
def _get_provider_configuration(self, tenant_id: str, provider: str):
|
||||
"""
|
||||
Get provider configuration or raise exception if not found.
|
||||
|
||||
Args:
|
||||
tenant_id: Workspace identifier
|
||||
provider: Provider name
|
||||
|
||||
Returns:
|
||||
Provider configuration instance
|
||||
|
||||
Raises:
|
||||
ProviderNotFoundError: If provider doesn't exist
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
|
||||
if not provider_configuration:
|
||||
raise ProviderNotFoundError(f"Provider {provider} does not exist.")
|
||||
|
||||
return provider_configuration
|
||||
|
||||
def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
|
||||
"""
|
||||
get provider list.
|
||||
|
||||
@ -46,6 +69,10 @@ class ModelProviderService:
|
||||
if model_type_entity not in provider_configuration.provider.supported_model_types:
|
||||
continue
|
||||
|
||||
provider_config = provider_configuration.custom_configuration.provider
|
||||
model_config = provider_configuration.custom_configuration.models
|
||||
can_added_models = provider_configuration.custom_configuration.can_added_models
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_configuration.provider.provider,
|
||||
@ -63,7 +90,12 @@ class ModelProviderService:
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE
|
||||
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||
current_credential_id=getattr(provider_config, "current_credential_id", None),
|
||||
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
||||
available_credentials=getattr(provider_config, "available_credentials", []),
|
||||
custom_models=model_config,
|
||||
can_added_models=can_added_models,
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
enabled=provider_configuration.system_configuration.enabled,
|
||||
@ -82,8 +114,8 @@ class ModelProviderService:
|
||||
For the model provider page,
|
||||
only supports passing in a single provider to query the list of supported models.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
@ -95,100 +127,109 @@ class ModelProviderService:
|
||||
for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
|
||||
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
|
||||
"""
|
||||
get provider credentials.
|
||||
"""
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
return provider_configuration.get_custom_credentials(obfuscated=True)
|
||||
|
||||
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate provider credentials.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:param credentials:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
provider_configuration.custom_credentials_validate(credentials)
|
||||
|
||||
def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
save custom provider config.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials
|
||||
:param credential_id: credential id, if not provided, return current used credentials
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Add or update custom provider credentials.
|
||||
provider_configuration.add_or_update_custom_credentials(credentials)
|
||||
|
||||
def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
|
||||
"""
|
||||
remove custom provider config.
|
||||
validate provider credentials before saving.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials dict
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.validate_provider_credentials(credentials)
|
||||
|
||||
def create_provider_credential(
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create and save new provider credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials dict
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.create_provider_credential(credentials, credential_name)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom provider credentials.
|
||||
provider_configuration.delete_custom_credentials()
|
||||
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]:
|
||||
def update_provider_credential(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
get model credentials.
|
||||
update a saved provider credential (by credential_id).
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials dict
|
||||
:param credential_id: credential id
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.update_provider_credential(
|
||||
credential_id=credential_id,
|
||||
credentials=credentials,
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
remove a saved provider credential (by credential_id).
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_provider_credential(credential_id=credential_id)
|
||||
|
||||
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.switch_active_provider_credential(credential_id=credential_id)
|
||||
|
||||
def get_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieve model-specific credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: Optional credential ID, uses current if not provided
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get model custom credentials from ProviderModel if exists
|
||||
return provider_configuration.get_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type), model=model, obfuscated=True
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_custom_model_credential( # type: ignore
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def model_credentials_validate(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
@ -196,49 +237,120 @@ class ModelProviderService:
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param credentials: model credentials dict
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Validate model credentials
|
||||
provider_configuration.custom_model_credentials_validate(
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def save_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
def create_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
save model credentials.
|
||||
create and save model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param credentials: model credentials dict
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Add or update custom model credentials
|
||||
provider_configuration.add_or_update_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.create_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
def update_model_credential(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
update model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials dict
|
||||
:param credential_id: credential id
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.update_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def switch_active_custom_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
):
|
||||
"""
|
||||
switch model credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.switch_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def add_model_credential_to_model_list(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
):
|
||||
"""
|
||||
add model credentials to model list.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.add_model_credential_to_model(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
@ -248,16 +360,8 @@ class ModelProviderService:
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom model credentials
|
||||
provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)
|
||||
|
||||
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
|
||||
"""
|
||||
@ -271,7 +375,7 @@ class ModelProviderService:
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider available models
|
||||
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
|
||||
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)
|
||||
|
||||
# Group models by provider
|
||||
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
|
||||
@ -282,9 +386,6 @@ class ModelProviderService:
|
||||
if model.deprecated:
|
||||
continue
|
||||
|
||||
if model.status != ModelStatus.ACTIVE:
|
||||
continue
|
||||
|
||||
provider_models[model.provider.provider].append(model)
|
||||
|
||||
# convert to ProviderWithModelsResponse list
|
||||
@ -331,13 +432,7 @@ class ModelProviderService:
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
|
||||
# fetch credentials
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
|
||||
@ -351,7 +446,7 @@ class ModelProviderService:
|
||||
|
||||
return model_schema.parameter_rules if model_schema else []
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
|
||||
"""
|
||||
get default model of model type.
|
||||
|
||||
@ -383,7 +478,7 @@ class ModelProviderService:
|
||||
logger.debug("get_default_model_of_model_type error: %s", e)
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
|
||||
"""
|
||||
update default model of model type.
|
||||
|
||||
@ -400,7 +495,7 @@ class ModelProviderService:
|
||||
|
||||
def get_model_provider_icon(
|
||||
self, tenant_id: str, provider: str, icon_type: str, lang: str
|
||||
) -> tuple[Optional[bytes], Optional[str]]:
|
||||
) -> tuple[bytes | None, str | None]:
|
||||
"""
|
||||
get model provider icon.
|
||||
|
||||
@ -415,7 +510,7 @@ class ModelProviderService:
|
||||
|
||||
return byte_data, mime_type
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
|
||||
"""
|
||||
switch preferred provider.
|
||||
|
||||
@ -424,21 +519,15 @@ class ModelProviderService:
|
||||
:param preferred_provider_type: preferred provider type
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
|
||||
# Convert preferred_provider_type to ProviderType
|
||||
preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Switch preferred provider type
|
||||
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
|
||||
|
||||
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
enable model.
|
||||
|
||||
@ -448,18 +537,10 @@ class ModelProviderService:
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
disable model.
|
||||
|
||||
@ -469,13 +550,5 @@ class ModelProviderService:
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
94
api/services/oauth_server.py
Normal file
94
api/services/oauth_server.py
Normal file
@ -0,0 +1,94 @@
|
||||
import enum
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class OAuthGrantType(enum.StrEnum):
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days
|
||||
|
||||
|
||||
class OAuthServerService:
|
||||
@staticmethod
|
||||
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
|
||||
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
return session.execute(query).scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
|
||||
code = str(uuid.uuid4())
|
||||
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
|
||||
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def sign_oauth_access_token(
|
||||
grant_type: OAuthGrantType,
|
||||
code: str = "",
|
||||
client_id: str = "",
|
||||
refresh_token: str = "",
|
||||
) -> tuple[str, str]:
|
||||
match grant_type:
|
||||
case OAuthGrantType.AUTHORIZATION_CODE:
|
||||
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
raise BadRequest("invalid code")
|
||||
|
||||
# delete code
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
|
||||
return access_token, refresh_token
|
||||
case OAuthGrantType.REFRESH_TOKEN:
|
||||
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
raise BadRequest("invalid refresh token")
|
||||
|
||||
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
|
||||
return access_token, refresh_token
|
||||
|
||||
@staticmethod
|
||||
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
|
||||
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
return None
|
||||
|
||||
user_id_str = user_account_id.decode("utf-8")
|
||||
|
||||
return AccountService.load_user(user_id_str)
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
@ -15,7 +15,7 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: Optional[TraceAppConfig] = (
|
||||
trace_config_data: TraceAppConfig | None = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
@ -134,17 +134,26 @@ class OpsService:
|
||||
|
||||
# get project url
|
||||
if tracing_provider in ("arize", "phoenix"):
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||
except Exception:
|
||||
project_url = None
|
||||
elif tracing_provider == "langfuse":
|
||||
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
|
||||
project_url = f"{tracing_config.get('host')}/project/{project_key}"
|
||||
try:
|
||||
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
|
||||
project_url = f"{tracing_config.get('host')}/project/{project_key}"
|
||||
except Exception:
|
||||
project_url = None
|
||||
elif tracing_provider in ("langsmith", "opik"):
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||
except Exception:
|
||||
project_url = None
|
||||
else:
|
||||
project_url = None
|
||||
|
||||
# check if trace config already exists
|
||||
trace_config_data: Optional[TraceAppConfig] = (
|
||||
trace_config_data: TraceAppConfig | None = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
|
||||
@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls) -> None:
|
||||
def migrate(cls):
|
||||
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
|
||||
@ -26,7 +26,7 @@ class PluginDataMigration:
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls) -> None:
|
||||
def migrate_datasets(cls):
|
||||
table_name = "datasets"
|
||||
provider_column_name = "embedding_model_provider"
|
||||
|
||||
@ -126,9 +126,7 @@ limit 1000"""
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(
|
||||
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
|
||||
) -> None:
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
@ -175,7 +173,7 @@ limit 1000"""
|
||||
# update jina to langgenius/jina_tool/jina etc.
|
||||
updated_value = provider_cls(provider_name).to_string()
|
||||
batch_updates.append((updated_value, record_id))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
|
||||
@ -10,7 +10,7 @@ class PluginAutoUpgradeService:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -26,7 +26,7 @@ class PluginAutoUpgradeService:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
@ -54,7 +54,7 @@ class PluginAutoUpgradeService:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
|
||||
@ -5,7 +5,7 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import click
|
||||
@ -33,7 +33,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int) -> None:
|
||||
def extract_plugins(cls, filepath: str, workers: int):
|
||||
"""
|
||||
Migrate plugin.
|
||||
"""
|
||||
@ -55,7 +55,7 @@ class PluginMigration:
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
def process_tenant(flask_app: Flask, tenant_id: str):
|
||||
with flask_app.app_context():
|
||||
nonlocal handled_tenant_count
|
||||
try:
|
||||
@ -99,6 +99,7 @@ class PluginMigration:
|
||||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
tenant_count = 0
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
@ -255,7 +256,7 @@ class PluginMigration:
|
||||
return []
|
||||
|
||||
agent_app_model_config_ids = [
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
|
||||
]
|
||||
|
||||
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
|
||||
@ -280,7 +281,7 @@ class PluginMigration:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
|
||||
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
|
||||
"""
|
||||
Fetch plugin unique identifier using plugin id.
|
||||
"""
|
||||
@ -291,7 +292,7 @@ class PluginMigration:
|
||||
return plugin_manifest[0].latest_package_identifier
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
@ -328,7 +329,7 @@ class PluginMigration:
|
||||
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
|
||||
|
||||
@classmethod
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
@ -348,7 +349,7 @@ class PluginMigration:
|
||||
if response.get("failed"):
|
||||
plugin_install_failed.extend(response.get("failed", []))
|
||||
|
||||
def install(tenant_id: str, plugin_ids: list[str]) -> None:
|
||||
def install(tenant_id: str, plugin_ids: list[str]):
|
||||
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
|
||||
# fetch plugin already installed
|
||||
installed_plugins = manager.list_plugins(tenant_id)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_type
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -46,11 +45,11 @@ class PluginService:
|
||||
REDIS_TTL = 60 * 5 # 5 minutes
|
||||
|
||||
@staticmethod
|
||||
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
|
||||
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
|
||||
"""
|
||||
Fetch the latest plugin version
|
||||
"""
|
||||
result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
|
||||
result: dict[str, PluginService.LatestPluginCache | None] = {}
|
||||
|
||||
try:
|
||||
cache_not_exists = []
|
||||
@ -109,7 +108,7 @@ class PluginService:
|
||||
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
|
||||
|
||||
@staticmethod
|
||||
def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
|
||||
def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
|
||||
"""
|
||||
Check the plugin installation scope
|
||||
"""
|
||||
@ -144,7 +143,7 @@ class PluginService:
|
||||
return manager.get_debugging_key(tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
|
||||
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
|
||||
"""
|
||||
List the latest versions of the plugins
|
||||
"""
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
@ -14,12 +13,12 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
Retrieval recommended app from buildin, the location is constants/recommended_apps.json
|
||||
"""
|
||||
|
||||
builtin_data: Optional[dict] = None
|
||||
builtin_data: dict | None = None
|
||||
|
||||
def get_type(self) -> str:
|
||||
return RecommendAppType.BUILDIN
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
result = self.fetch_recommended_apps_from_builtin(language)
|
||||
return result
|
||||
|
||||
@ -28,7 +27,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _get_builtin_data(cls) -> dict:
|
||||
def _get_builtin_data(cls):
|
||||
"""
|
||||
Get builtin data.
|
||||
:return:
|
||||
@ -44,7 +43,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return cls.builtin_data or {}
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_builtin(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from builtin.
|
||||
:param language: language
|
||||
@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return builtin_data.get("recommended_apps", {}).get(language, {})
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
|
||||
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
|
||||
"""
|
||||
Fetch recommended app detail from builtin.
|
||||
:param app_id: App ID
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
Retrieval recommended app from database
|
||||
"""
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
result = self.fetch_recommended_apps_from_db(language)
|
||||
return result
|
||||
|
||||
@ -25,24 +25,20 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return RecommendAppType.DATABASE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_db(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
.all()
|
||||
)
|
||||
recommended_apps = db.session.scalars(
|
||||
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
).all()
|
||||
|
||||
if len(recommended_apps) == 0:
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
.all()
|
||||
)
|
||||
recommended_apps = db.session.scalars(
|
||||
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
).all()
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
@ -74,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
|
||||
"""
|
||||
Fetch recommended app detail from db.
|
||||
:param app_id: App ID
|
||||
|
||||
@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC):
|
||||
"""Interface for recommend app retrieval."""
|
||||
|
||||
@abstractmethod
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
@ -24,7 +23,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)
|
||||
return result
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
try:
|
||||
result = self.fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return RecommendAppType.REMOTE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
|
||||
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
|
||||
"""
|
||||
Fetch recommended app detail from dify official.
|
||||
:param app_id: App ID
|
||||
@ -51,7 +50,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from dify official.
|
||||
:param language: language
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||
|
||||
|
||||
class RecommendedAppService:
|
||||
@classmethod
|
||||
def get_recommended_apps_and_categories(cls, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(cls, language: str):
|
||||
"""
|
||||
Get recommended apps and categories.
|
||||
:param language: language
|
||||
@ -15,7 +13,7 @@ class RecommendedAppService:
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result = retrieval_instance.get_recommended_apps_and_categories(language)
|
||||
if not result.get("recommended_apps") and language != "en-US":
|
||||
if not result.get("recommended_apps"):
|
||||
result = (
|
||||
RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin(
|
||||
"en-US"
|
||||
@ -25,7 +23,7 @@ class RecommendedAppService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
|
||||
def get_recommend_app_detail(cls, app_id: str) -> dict | None:
|
||||
"""
|
||||
Get recommend app detail.
|
||||
:param app_id: app id
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
@ -11,7 +11,7 @@ from services.message_service import MessageService
|
||||
class SavedMessageService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(
|
||||
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
|
||||
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
raise ValueError("User is required")
|
||||
@ -32,7 +32,7 @@ class SavedMessageService:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = (
|
||||
@ -62,7 +62,7 @@ class SavedMessageService:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = (
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding
|
||||
|
||||
class TagService:
|
||||
@staticmethod
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
|
||||
query = (
|
||||
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
@ -25,46 +24,41 @@ class TagService:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
)
|
||||
tags = db.session.scalars(
|
||||
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
).all()
|
||||
if not tags:
|
||||
return []
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
tag_bindings = (
|
||||
db.session.query(TagBinding.target_id)
|
||||
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
if not tag_bindings:
|
||||
return []
|
||||
results = [tag_binding.target_id for tag_binding in tag_bindings]
|
||||
return results
|
||||
tag_bindings = db.session.scalars(
|
||||
select(TagBinding.target_id).where(
|
||||
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
|
||||
)
|
||||
).all()
|
||||
return tag_bindings
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
|
||||
if not tag_type or not tag_name:
|
||||
return []
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
tags = list(
|
||||
db.session.scalars(
|
||||
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
).all()
|
||||
)
|
||||
if not tags:
|
||||
return []
|
||||
return tags
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
@ -117,7 +111,7 @@ class TagService:
|
||||
raise NotFound("Tag not found")
|
||||
db.session.delete(tag)
|
||||
# delete tag binding
|
||||
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
|
||||
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
|
||||
if tag_bindings:
|
||||
for tag_binding in tag_bindings:
|
||||
db.session.delete(tag_binding)
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -443,9 +444,7 @@ class ApiToolManageService:
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
|
||||
@ -3,8 +3,9 @@ import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@ -190,11 +191,14 @@ class BuiltinToolManageService:
|
||||
# update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# check if the name is already used
|
||||
if (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||
.count()
|
||||
> 0
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
@ -219,8 +223,8 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
@ -246,11 +250,14 @@ class BuiltinToolManageService:
|
||||
)
|
||||
else:
|
||||
# check if the name is already used
|
||||
if (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||
.count()
|
||||
> 0
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
@ -278,9 +285,9 @@ class BuiltinToolManageService:
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@ -453,7 +460,7 @@ class BuiltinToolManageService:
|
||||
check if oauth system client exists
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider_name)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
system_client: ToolOAuthSystemClient | None = (
|
||||
session.query(ToolOAuthSystemClient)
|
||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||
@ -467,7 +474,7 @@ class BuiltinToolManageService:
|
||||
check if oauth custom client is enabled
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@ -492,7 +499,7 @@ class BuiltinToolManageService:
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@ -546,65 +553,64 @@ class BuiltinToolManageService:
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
|
||||
|
||||
with db.session.no_autoflush:
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
|
||||
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
|
||||
# find provider
|
||||
def find_provider(provider):
|
||||
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
# find provider
|
||||
def find_provider(provider):
|
||||
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.entity.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
"""
|
||||
This method is used to fetch the builtin provider from the database
|
||||
1.if the default provider exists, return the default provider
|
||||
2.if the default provider does not exist, return the oldest provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
@ -659,8 +665,8 @@ class BuiltinToolManageService:
|
||||
def save_custom_oauth_client_params(
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
client_params: Optional[dict] = None,
|
||||
enable_oauth_custom_client: Optional[bool] = None,
|
||||
client_params: dict | None = None,
|
||||
enable_oauth_custom_client: bool | None = None,
|
||||
):
|
||||
"""
|
||||
setup oauth custom client
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@ -27,6 +27,36 @@ class MCPToolManageService:
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||
"""
|
||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
||||
|
||||
Args:
|
||||
headers: Dictionary of headers to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
|
||||
Returns:
|
||||
Dictionary with all headers encrypted
|
||||
"""
|
||||
if not headers:
|
||||
return {}
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return cast(dict[str, str], encrypter_instance.encrypt(headers))
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
@ -61,6 +91,7 @@ class MCPToolManageService:
|
||||
server_identifier: str,
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
@ -83,6 +114,12 @@ class MCPToolManageService:
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
# Encrypt headers
|
||||
encrypted_headers = None
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
@ -95,6 +132,7 @@ class MCPToolManageService:
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
@ -118,9 +156,21 @@ class MCPToolManageService:
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
server_url = mcp_provider.decrypted_server_url
|
||||
authed = mcp_provider.authed
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=authed,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
except MCPAuthError:
|
||||
raise ValueError("Please auth the tool first")
|
||||
@ -172,6 +222,7 @@ class MCPToolManageService:
|
||||
server_identifier: str,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
|
||||
@ -207,6 +258,13 @@ class MCPToolManageService:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
if headers is not None:
|
||||
# Encrypt headers
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
else:
|
||||
mcp_provider.encrypted_headers = None
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
@ -226,10 +284,10 @@ class MCPToolManageService:
|
||||
def update_mcp_provider_credentials(
|
||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
):
|
||||
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=mcp_provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
|
||||
provider_config_cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
@ -242,6 +300,12 @@ class MCPToolManageService:
|
||||
|
||||
@classmethod
|
||||
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
|
||||
# Get the existing provider to access headers and timeout settings
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
@ -249,6 +313,9 @@ class MCPToolManageService:
|
||||
tenant_id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return {
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@ -94,7 +94,7 @@ class ToolTransformService:
|
||||
def builtin_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
db_provider: BuiltinToolProvider | None,
|
||||
decrypt_credentials: bool = True,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
@ -128,7 +128,7 @@ class ToolTransformService:
|
||||
)
|
||||
}
|
||||
|
||||
for name, value in schema.items():
|
||||
for name in schema:
|
||||
if result.masked_credentials:
|
||||
result.masked_credentials[name] = ""
|
||||
|
||||
@ -237,6 +237,10 @@ class ToolTransformService:
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
server_identifier=db_provider.server_identifier,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
masked_headers=db_provider.masked_headers,
|
||||
original_headers=db_provider.decrypted_headers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -3,7 +3,7 @@ from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -37,7 +37,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
@ -103,7 +103,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -186,7 +186,9 @@ class WorkflowToolManageService:
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
db_tools = db.session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
@ -217,7 +219,7 @@ class WorkflowToolManageService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Delete a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -233,7 +235,7 @@ class WorkflowToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -249,7 +251,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -265,7 +267,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@ -13,13 +12,13 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorService:
|
||||
@classmethod
|
||||
def create_segments_vector(
|
||||
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
|
||||
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
|
||||
):
|
||||
documents: list[Document] = []
|
||||
|
||||
@ -27,7 +26,7 @@ class VectorService:
|
||||
if doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
||||
if not dataset_document:
|
||||
_logger.warning(
|
||||
logger.warning(
|
||||
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
|
||||
segment.document_id,
|
||||
segment.id,
|
||||
@ -79,7 +78,7 @@ class VectorService:
|
||||
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
|
||||
|
||||
@classmethod
|
||||
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
|
||||
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
|
||||
# update segment index task
|
||||
|
||||
# format new index
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -19,11 +19,11 @@ class WebConversationService:
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
user: Union[Account, EndUser] | None,
|
||||
last_id: str | None,
|
||||
limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
pinned: Optional[bool] = None,
|
||||
pinned: bool | None = None,
|
||||
sort_by="-updated_at",
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
@ -60,7 +60,7 @@ class WebConversationService:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
if not user:
|
||||
return
|
||||
pinned_conversation = (
|
||||
@ -92,7 +92,7 @@ class WebConversationService:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
if not user:
|
||||
return
|
||||
pinned_conversation = (
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any
|
||||
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
@ -42,7 +42,7 @@ class WebAppAuthService:
|
||||
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
return cast(Account, account)
|
||||
return account
|
||||
|
||||
@classmethod
|
||||
def login(cls, account: Account) -> str:
|
||||
@ -63,7 +63,7 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def send_email_code_login_email(
|
||||
cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US"
|
||||
cls, account: Account | None = None, email: str | None = None, language: str = "en-US"
|
||||
):
|
||||
email = account.email if account else email
|
||||
if email is None:
|
||||
@ -82,7 +82,7 @@ class WebAppAuthService:
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "email_code_login")
|
||||
|
||||
@classmethod
|
||||
@ -113,7 +113,7 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def _get_account_jwt_token(cls, account: Account) -> str:
|
||||
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
|
||||
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
|
||||
exp = int(exp_dt.timestamp())
|
||||
|
||||
payload = {
|
||||
@ -130,7 +130,7 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def is_app_require_permission_check(
|
||||
cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
|
||||
cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the app requires permission check based on its access mode.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
@ -21,9 +21,9 @@ class CrawlOptions:
|
||||
limit: int = 1
|
||||
crawl_sub_pages: bool = False
|
||||
only_main_content: bool = False
|
||||
includes: Optional[str] = None
|
||||
excludes: Optional[str] = None
|
||||
max_depth: Optional[int] = None
|
||||
includes: str | None = None
|
||||
excludes: str | None = None
|
||||
max_depth: int | None = None
|
||||
use_sitemap: bool = True
|
||||
|
||||
def get_include_paths(self) -> list[str]:
|
||||
@ -132,7 +132,7 @@ class WebsiteService:
|
||||
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict) -> None:
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
"""Validate arguments for document creation."""
|
||||
try:
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@ -18,6 +18,7 @@ from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.nodes import NodeType
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
@ -64,7 +65,7 @@ class WorkflowConverter:
|
||||
new_app = App()
|
||||
new_app.tenant_id = app_model.tenant_id
|
||||
new_app.name = name or app_model.name + "(workflow)"
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
|
||||
new_app.icon_type = icon_type or app_model.icon_type
|
||||
new_app.icon = icon or app_model.icon
|
||||
new_app.icon_background = icon_background or app_model.icon_background
|
||||
@ -202,7 +203,7 @@ class WorkflowConverter:
|
||||
app_mode_enum = AppMode.value_of(app_model.mode)
|
||||
app_config: EasyUIBasedAppConfig
|
||||
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
app_model.mode = AppMode.AGENT_CHAT.value
|
||||
app_model.mode = AppMode.AGENT_CHAT
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
@ -217,7 +218,7 @@ class WorkflowConverter:
|
||||
|
||||
return app_config
|
||||
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]):
|
||||
"""
|
||||
Convert to Start Node
|
||||
:param variables: list of variables
|
||||
@ -278,7 +279,7 @@ class WorkflowConverter:
|
||||
"app_id": app_model.id,
|
||||
"tool_variable": tool_variable,
|
||||
"inputs": inputs,
|
||||
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
|
||||
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "",
|
||||
},
|
||||
}
|
||||
|
||||
@ -326,7 +327,7 @@ class WorkflowConverter:
|
||||
|
||||
def _convert_to_knowledge_retrieval_node(
|
||||
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
|
||||
) -> Optional[dict]:
|
||||
) -> dict | None:
|
||||
"""
|
||||
Convert datasets to Knowledge Retrieval Node
|
||||
:param new_app_mode: new app mode
|
||||
@ -382,9 +383,9 @@ class WorkflowConverter:
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileUploadConfig] = None,
|
||||
file_upload: FileUploadConfig | None = None,
|
||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param original_app_mode: original app mode
|
||||
@ -402,7 +403,7 @@ class WorkflowConverter:
|
||||
)
|
||||
|
||||
role_prefix = None
|
||||
prompts: Optional[Any] = None
|
||||
prompts: Any | None = None
|
||||
|
||||
# Chat Model
|
||||
if model_config.mode == LLMMode.CHAT.value:
|
||||
@ -420,7 +421,11 @@ class WorkflowConverter:
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
prompt_template_obj = prompt_template_config["prompt_template"]
|
||||
if not isinstance(prompt_template_obj, PromptTemplateParser):
|
||||
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
|
||||
|
||||
template = prompt_template_obj.template
|
||||
if not template:
|
||||
prompts = []
|
||||
else:
|
||||
@ -457,7 +462,11 @@ class WorkflowConverter:
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
prompt_template_obj = prompt_template_config["prompt_template"]
|
||||
if not isinstance(prompt_template_obj, PromptTemplateParser):
|
||||
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
|
||||
|
||||
template = prompt_template_obj.template
|
||||
template = self._replace_template_variables(
|
||||
template=template,
|
||||
variables=start_node["data"]["variables"],
|
||||
@ -467,6 +476,9 @@ class WorkflowConverter:
|
||||
prompts = {"text": template}
|
||||
|
||||
prompt_rules = prompt_template_config["prompt_rules"]
|
||||
if not isinstance(prompt_rules, dict):
|
||||
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
|
||||
|
||||
role_prefix = {
|
||||
"user": prompt_rules.get("human_prefix", "Human"),
|
||||
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
|
||||
@ -550,7 +562,7 @@ class WorkflowConverter:
|
||||
|
||||
return template
|
||||
|
||||
def _convert_to_end_node(self) -> dict:
|
||||
def _convert_to_end_node(self):
|
||||
"""
|
||||
Convert to End Node
|
||||
:return:
|
||||
@ -566,7 +578,7 @@ class WorkflowConverter:
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_answer_node(self) -> dict:
|
||||
def _convert_to_answer_node(self):
|
||||
"""
|
||||
Convert to Answer Node
|
||||
:return:
|
||||
@ -578,7 +590,7 @@ class WorkflowConverter:
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
def _create_edge(self, source: str, target: str):
|
||||
"""
|
||||
Create Edge
|
||||
:param source: source node id
|
||||
@ -587,7 +599,7 @@ class WorkflowConverter:
|
||||
"""
|
||||
return {"id": f"{source}-{target}", "source": source, "target": target}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
def _append_node(self, graph: dict, node: dict):
|
||||
"""
|
||||
Append Node to Graph
|
||||
|
||||
@ -606,7 +618,7 @@ class WorkflowConverter:
|
||||
:param app_model: App instance
|
||||
:return: AppMode
|
||||
"""
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
if app_model.mode == AppMode.COMPLETION:
|
||||
return AppMode.WORKFLOW
|
||||
else:
|
||||
return AppMode.ADVANCED_CHAT
|
||||
|
||||
@ -23,7 +23,7 @@ class WorkflowAppService:
|
||||
limit: int = 20,
|
||||
created_by_end_user_session_id: str | None = None,
|
||||
created_by_account: str | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Get paginate workflow app logs using SQLAlchemy 2.0 style
|
||||
:param session: SQLAlchemy session
|
||||
|
||||
@ -28,7 +28,7 @@ from models.enums import DraftVariableType
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -67,7 +67,7 @@ class DraftVarLoader(VariableLoader):
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
fallback_variables: Sequence[Variable] | None = None,
|
||||
) -> None:
|
||||
):
|
||||
self._engine = engine
|
||||
self._app_id = app_id
|
||||
self._tenant_id = tenant_id
|
||||
@ -117,7 +117,7 @@ class DraftVarLoader(VariableLoader):
|
||||
class WorkflowDraftVariableService:
|
||||
_session: Session
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
def __init__(self, session: Session):
|
||||
"""
|
||||
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
|
||||
|
||||
@ -242,7 +242,7 @@ class WorkflowDraftVariableService:
|
||||
if conv_var is None:
|
||||
self._session.delete(instance=variable)
|
||||
self._session.flush()
|
||||
_logger.warning(
|
||||
logger.warning(
|
||||
"Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name
|
||||
)
|
||||
return None
|
||||
@ -263,12 +263,12 @@ class WorkflowDraftVariableService:
|
||||
if variable.node_execution_id is None:
|
||||
self._session.delete(instance=variable)
|
||||
self._session.flush()
|
||||
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
|
||||
logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
|
||||
return None
|
||||
|
||||
node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
|
||||
if node_exec is None:
|
||||
_logger.warning(
|
||||
logger.warning(
|
||||
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
|
||||
variable.id,
|
||||
variable.name,
|
||||
@ -351,7 +351,7 @@ class WorkflowDraftVariableService:
|
||||
return None
|
||||
segment = draft_var.get_value()
|
||||
if not isinstance(segment, StringSegment):
|
||||
_logger.warning(
|
||||
logger.warning(
|
||||
"sys.conversation_id variable is not a string: app_id=%s, id=%s",
|
||||
app_id,
|
||||
draft_var.id,
|
||||
@ -438,7 +438,7 @@ def _batch_upsert_draft_variable(
|
||||
session: Session,
|
||||
draft_vars: Sequence[WorkflowDraftVariable],
|
||||
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
|
||||
) -> None:
|
||||
):
|
||||
if not draft_vars:
|
||||
return None
|
||||
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
|
||||
@ -681,7 +681,7 @@ class DraftVariableSaver:
|
||||
draft_vars = []
|
||||
for name, value in output.items():
|
||||
if not self._should_variable_be_saved(name):
|
||||
_logger.debug(
|
||||
logger.debug(
|
||||
"Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s",
|
||||
name,
|
||||
self._node_type,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -80,7 +79,7 @@ class WorkflowRunService:
|
||||
last_id=last_id,
|
||||
)
|
||||
|
||||
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
|
||||
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
|
||||
"""
|
||||
Get workflow run detail
|
||||
|
||||
|
||||
@ -2,10 +2,10 @@ import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
@ -37,23 +37,15 @@ from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowType,
|
||||
)
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
from .chatflow_memory_service import ChatflowMemoryService
|
||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
from .workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
@ -89,20 +81,21 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
def is_workflow_exist(self, app_model: App) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
stmt = select(
|
||||
exists().where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.count()
|
||||
) > 0
|
||||
)
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
if workflow_id:
|
||||
return self.get_published_workflow_by_id(app_model, workflow_id)
|
||||
# fetch draft workflow by app_model
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
@ -117,8 +110,10 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
# fetch published workflow by workflow_id
|
||||
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
fetch published workflow by workflow_id
|
||||
"""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
@ -137,7 +132,7 @@ class WorkflowService:
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
def get_published_workflow(self, app_model: App) -> Workflow | None:
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
@ -202,7 +197,7 @@ class WorkflowService:
|
||||
app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
unique_hash: Optional[str],
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
@ -270,6 +265,12 @@ class WorkflowService:
|
||||
if not draft_workflow:
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# Validate credentials before publishing, for credential policy check
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
if FeatureService.get_system_features().plugin_manager.enabled:
|
||||
self._validate_workflow_credentials(draft_workflow)
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app_model.tenant_id,
|
||||
@ -294,6 +295,260 @@ class WorkflowService:
|
||||
# return new workflow
|
||||
return workflow
|
||||
|
||||
def _validate_workflow_credentials(self, workflow: Workflow) -> None:
|
||||
"""
|
||||
Validate all credentials in workflow nodes before publishing.
|
||||
|
||||
:param workflow: The workflow to validate
|
||||
:raises ValueError: If any credentials violate policy compliance
|
||||
"""
|
||||
graph_dict = workflow.graph_dict
|
||||
nodes = graph_dict.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
node_type = node_data.get("type")
|
||||
node_id = node.get("id", "unknown")
|
||||
|
||||
try:
|
||||
# Extract and validate credentials based on node type
|
||||
if node_type == "tool":
|
||||
credential_id = node_data.get("credential_id")
|
||||
provider = node_data.get("provider_id")
|
||||
if provider:
|
||||
if credential_id:
|
||||
# Check specific credential
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=credential_id,
|
||||
provider=provider,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
)
|
||||
else:
|
||||
# Check default workspace credential for this provider
|
||||
self._check_default_tool_credential(workflow.tenant_id, provider)
|
||||
|
||||
elif node_type == "agent":
|
||||
agent_params = node_data.get("agent_parameters", {})
|
||||
|
||||
model_config = agent_params.get("model", {}).get("value", {})
|
||||
if model_config.get("provider") and model_config.get("model"):
|
||||
self._validate_llm_model_config(
|
||||
workflow.tenant_id, model_config["provider"], model_config["model"]
|
||||
)
|
||||
|
||||
# Validate load balancing credentials for agent model if load balancing is enabled
|
||||
agent_model_node_data = {"model": model_config}
|
||||
self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)
|
||||
|
||||
# Validate agent tools
|
||||
tools = agent_params.get("tools", {}).get("value", [])
|
||||
for tool in tools:
|
||||
# Agent tools store provider in provider_name field
|
||||
provider = tool.get("provider_name")
|
||||
credential_id = tool.get("credential_id")
|
||||
if provider:
|
||||
if credential_id:
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
|
||||
else:
|
||||
self._check_default_tool_credential(workflow.tenant_id, provider)
|
||||
|
||||
elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
|
||||
model_config = node_data.get("model", {})
|
||||
provider = model_config.get("provider")
|
||||
model_name = model_config.get("name")
|
||||
|
||||
if provider and model_name:
|
||||
# Validate that the provider+model combination can fetch valid credentials
|
||||
self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
|
||||
# Validate load balancing credentials if load balancing is enabled
|
||||
self._validate_load_balancing_credentials(workflow, node_data, node_id)
|
||||
else:
|
||||
raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise e
|
||||
else:
|
||||
raise ValueError(f"Node {node_id} ({node_type}): {str(e)}")
|
||||
|
||||
def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
|
||||
"""
|
||||
Validate that an LLM model configuration can fetch valid credentials.
|
||||
|
||||
This method attempts to get the model instance and validates that:
|
||||
1. The provider exists and is configured
|
||||
2. The model exists in the provider
|
||||
3. Credentials can be fetched for the model
|
||||
4. The credentials pass policy compliance checks
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider name
|
||||
:param model_name: The model name
|
||||
:raises ValueError: If the model configuration is invalid or credentials fail policy checks
|
||||
"""
|
||||
try:
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
# Get model instance to validate provider+model combination
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
|
||||
# The ModelInstance constructor will automatically check credential policy compliance
|
||||
# via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
|
||||
# If it fails, an exception will be raised
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
|
||||
)
|
||||
|
||||
def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None:
|
||||
"""
|
||||
Check credential policy compliance for the default workspace credential of a tool provider.
|
||||
|
||||
This method finds the default credential for the given provider and validates it.
|
||||
Uses the same fallback logic as runtime to handle deauthorized credentials.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The tool provider name
|
||||
:raises ValueError: If no default credential exists or if it fails policy compliance
|
||||
"""
|
||||
try:
|
||||
from models.tools import BuiltinToolProvider
|
||||
|
||||
# Use the same fallback logic as runtime: get the first available credential
|
||||
# ordered by is_default DESC, created_at ASC (same as tool_manager.py)
|
||||
default_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not default_provider:
|
||||
raise ValueError("No default credential found")
|
||||
|
||||
# Check credential policy compliance using the default credential ID
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=default_provider.id,
|
||||
provider=provider,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
|
||||
|
||||
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
|
||||
"""
|
||||
Validate load balancing credentials for a workflow node.
|
||||
|
||||
:param workflow: The workflow being validated
|
||||
:param node_data: The node data containing model configuration
|
||||
:param node_id: The node ID for error reporting
|
||||
:raises ValueError: If load balancing credentials violate policy compliance
|
||||
"""
|
||||
# Extract model configuration
|
||||
model_config = node_data.get("model", {})
|
||||
provider = model_config.get("provider")
|
||||
model_name = model_config.get("name")
|
||||
|
||||
if not provider or not model_name:
|
||||
return # No model config to validate
|
||||
|
||||
# Check if this model has load balancing enabled
|
||||
if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
|
||||
# Get all load balancing configurations for this model
|
||||
load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
|
||||
# Validate each load balancing configuration
|
||||
try:
|
||||
for config in load_balancing_configs:
|
||||
if config.get("credential_id"):
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
config["credential_id"], provider, PluginCredentialType.MODEL
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")
|
||||
|
||||
def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
|
||||
"""
|
||||
Check if load balancing is enabled for a specific model.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider name
|
||||
:param model_name: The model name
|
||||
:return: True if load balancing is enabled, False otherwise
|
||||
"""
|
||||
try:
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
# Get provider configurations
|
||||
provider_manager = ProviderManager()
|
||||
provider_configurations = provider_manager.get_configurations(tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
|
||||
if not provider_configuration:
|
||||
return False
|
||||
|
||||
# Get provider model setting
|
||||
provider_model_setting = provider_configuration.get_provider_model_setting(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
return provider_model_setting is not None and provider_model_setting.load_balancing_enabled
|
||||
|
||||
except Exception:
|
||||
# If we can't determine the status, assume load balancing is not enabled
|
||||
return False
|
||||
|
||||
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
|
||||
"""
|
||||
Get all load balancing configurations for a model.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider name
|
||||
:param model_name: The model name
|
||||
:return: List of load balancing configuration dictionaries
|
||||
"""
|
||||
try:
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
_, configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
model_type="llm", # Load balancing is primarily used for LLM models
|
||||
config_from="predefined-model", # Check both predefined and custom models
|
||||
)
|
||||
|
||||
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
|
||||
)
|
||||
all_configs = configs + custom_configs
|
||||
|
||||
return [config for config in all_configs if config.get("credential_id")]
|
||||
|
||||
except Exception:
|
||||
# If we can't get the configurations, return empty list
|
||||
# This will prevent validation errors from breaking the workflow
|
||||
return []
|
||||
|
||||
def get_default_block_configs(self) -> list[dict]:
|
||||
"""
|
||||
Get default block configs
|
||||
@ -308,7 +563,7 @@ class WorkflowService:
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
@ -509,10 +764,10 @@ class WorkflowService:
|
||||
)
|
||||
error = node_run_result.error if not run_succeeded else None
|
||||
except WorkflowNodeRunFailedError as e:
|
||||
node = e._node
|
||||
node = e.node
|
||||
run_succeeded = False
|
||||
node_run_result = None
|
||||
error = e._error
|
||||
error = e.error
|
||||
|
||||
# Create a NodeExecution domain model
|
||||
node_execution = WorkflowNodeExecution(
|
||||
@ -568,7 +823,7 @@ class WorkflowService:
|
||||
# chatbot convert to workflow mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
|
||||
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
|
||||
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
|
||||
# convert to workflow
|
||||
@ -583,12 +838,12 @@ class WorkflowService:
|
||||
|
||||
return new_app
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
return WorkflowAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
@ -597,7 +852,7 @@ class WorkflowService:
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
) -> Optional[Workflow]:
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Update workflow attributes
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ class WorkspaceService:
|
||||
def get_tenant_info(cls, tenant: Tenant):
|
||||
if not tenant:
|
||||
return None
|
||||
tenant_info = {
|
||||
tenant_info: dict[str, object] = {
|
||||
"id": tenant.id,
|
||||
"name": tenant.name,
|
||||
"plan": tenant.plan,
|
||||
|
||||
Reference in New Issue
Block a user