mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
Merge main into feat/plugin
This commit is contained in:
@ -1,3 +1,3 @@
|
||||
from . import errors
|
||||
|
||||
__all__ = ['errors']
|
||||
__all__ = ["errors"]
|
||||
|
||||
@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task
|
||||
|
||||
|
||||
class AccountService:
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(
|
||||
prefix="reset_password_rate_limit",
|
||||
max_attempts=5,
|
||||
time_window=60 * 60
|
||||
)
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
@ -55,12 +50,15 @@ class AccountService:
|
||||
if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
|
||||
raise Unauthorized("Account is banned or closed.")
|
||||
|
||||
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
||||
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
|
||||
account_id=account.id, current=True
|
||||
).first()
|
||||
if current_tenant:
|
||||
account.current_tenant_id = current_tenant.tenant_id
|
||||
else:
|
||||
available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
|
||||
.order_by(TenantAccountJoin.id.asc()).first()
|
||||
available_ta = (
|
||||
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
||||
)
|
||||
if not available_ta:
|
||||
return None
|
||||
|
||||
@ -74,14 +72,13 @@ class AccountService:
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
|
||||
payload = {
|
||||
"user_id": account.id,
|
||||
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
|
||||
"iss": dify_config.EDITION,
|
||||
"sub": 'Console API Passport',
|
||||
"sub": "Console API Passport",
|
||||
}
|
||||
|
||||
token = PassportService().issue(payload)
|
||||
@ -93,10 +90,10 @@ class AccountService:
|
||||
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
if not account:
|
||||
raise AccountLoginError('Invalid email or password.')
|
||||
raise AccountLoginError("Invalid email or password.")
|
||||
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise AccountLoginError('Account is banned or closed.')
|
||||
raise AccountLoginError("Account is banned or closed.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
@ -104,7 +101,7 @@ class AccountService:
|
||||
db.session.commit()
|
||||
|
||||
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||
raise AccountLoginError('Invalid email or password.')
|
||||
raise AccountLoginError("Invalid email or password.")
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@ -129,11 +126,9 @@ class AccountService:
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_account(email: str,
|
||||
name: str,
|
||||
interface_language: str,
|
||||
password: Optional[str] = None,
|
||||
interface_theme: str = 'light') -> Account:
|
||||
def create_account(
|
||||
email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light"
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
account = Account()
|
||||
account.email = email
|
||||
@ -155,7 +150,7 @@ class AccountService:
|
||||
account.interface_theme = interface_theme
|
||||
|
||||
# Set timezone based on language
|
||||
account.timezone = language_timezone_mapping.get(interface_language, 'UTC')
|
||||
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
@ -166,8 +161,9 @@ class AccountService:
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id,
|
||||
provider=provider).first()
|
||||
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
|
||||
account_id=account.id, provider=provider
|
||||
).first()
|
||||
|
||||
if account_integrate:
|
||||
# If it exists, update the record
|
||||
@ -176,15 +172,16 @@ class AccountService:
|
||||
account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
else:
|
||||
# If it does not exist, create a new record
|
||||
account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id,
|
||||
encrypted_token="")
|
||||
account_integrate = AccountIntegrate(
|
||||
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
|
||||
)
|
||||
db.session.add(account_integrate)
|
||||
|
||||
db.session.commit()
|
||||
logging.info(f'Account {account.id} linked {provider} account {open_id}.')
|
||||
logging.info(f"Account {account.id} linked {provider} account {open_id}.")
|
||||
except Exception as e:
|
||||
logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}')
|
||||
raise LinkAccountIntegrateError('Failed to link account.') from e
|
||||
logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
|
||||
raise LinkAccountIntegrateError("Failed to link account.") from e
|
||||
|
||||
@staticmethod
|
||||
def close_account(account: Account) -> None:
|
||||
@ -218,7 +215,7 @@ class AccountService:
|
||||
AccountService.update_last_login(account, ip_address=ip_address)
|
||||
exp = timedelta(days=30)
|
||||
token = AccountService.get_account_jwt_token(account, exp=exp)
|
||||
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds()))
|
||||
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
@ -236,22 +233,18 @@ class AccountService:
|
||||
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
|
||||
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
|
||||
|
||||
token = TokenManager.generate_token(account, 'reset_password')
|
||||
send_reset_password_mail_task.delay(
|
||||
language=account.interface_language,
|
||||
to=account.email,
|
||||
token=token
|
||||
)
|
||||
token = TokenManager.generate_token(account, "reset_password")
|
||||
send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token)
|
||||
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_reset_password_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, 'reset_password')
|
||||
TokenManager.revoke_token(token, "reset_password")
|
||||
|
||||
@classmethod
|
||||
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
return TokenManager.get_token_data(token, 'reset_password')
|
||||
return TokenManager.get_token_data(token, "reset_password")
|
||||
|
||||
|
||||
def _get_login_cache_key(*, account_id: str, token: str):
|
||||
@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str):
|
||||
|
||||
|
||||
class TenantService:
|
||||
|
||||
@staticmethod
|
||||
def create_tenant(name: str) -> Tenant:
|
||||
"""Create tenant"""
|
||||
@ -275,31 +267,28 @@ class TenantService:
|
||||
@staticmethod
|
||||
def create_owner_tenant_if_not_exist(account: Account):
|
||||
"""Create owner tenant if not exist"""
|
||||
available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
|
||||
.order_by(TenantAccountJoin.id.asc()).first()
|
||||
available_ta = (
|
||||
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
||||
)
|
||||
|
||||
if available_ta:
|
||||
return
|
||||
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
db.session.commit()
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
@staticmethod
|
||||
def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin:
|
||||
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
|
||||
"""Create tenant member"""
|
||||
if role == TenantAccountJoinRole.OWNER.value:
|
||||
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
|
||||
logging.error(f'Tenant {tenant.id} has already an owner.')
|
||||
raise Exception('Tenant already has an owner.')
|
||||
logging.error(f"Tenant {tenant.id} has already an owner.")
|
||||
raise Exception("Tenant already has an owner.")
|
||||
|
||||
ta = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role
|
||||
)
|
||||
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
|
||||
db.session.add(ta)
|
||||
db.session.commit()
|
||||
return ta
|
||||
@ -307,9 +296,12 @@ class TenantService:
|
||||
@staticmethod
|
||||
def get_join_tenants(account: Account) -> list[Tenant]:
|
||||
"""Get account join tenants"""
|
||||
return db.session.query(Tenant).join(
|
||||
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
|
||||
).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()
|
||||
return (
|
||||
db.session.query(Tenant)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_current_tenant_by_account(account: Account):
|
||||
@ -333,16 +325,23 @@ class TenantService:
|
||||
if tenant_id is None:
|
||||
raise ValueError("Tenant ID must be provided.")
|
||||
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
).first()
|
||||
tenant_account_join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant_account_join:
|
||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||
else:
|
||||
TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
|
||||
TenantAccountJoin.query.filter(
|
||||
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
|
||||
).update({"current": False})
|
||||
tenant_account_join.current = True
|
||||
# Set the current tenant for the account
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
@ -354,9 +353,7 @@ class TenantService:
|
||||
query = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
.select_from(Account)
|
||||
.join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id)
|
||||
)
|
||||
|
||||
@ -375,11 +372,9 @@ class TenantService:
|
||||
query = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
.select_from(Account)
|
||||
.join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id)
|
||||
.filter(TenantAccountJoin.role == 'dataset_operator')
|
||||
.filter(TenantAccountJoin.role == "dataset_operator")
|
||||
)
|
||||
|
||||
# Initialize an empty list to store the updated accounts
|
||||
@ -395,20 +390,25 @@ class TenantService:
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool:
|
||||
"""Check if user has any of the given roles for a tenant"""
|
||||
if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
|
||||
raise ValueError('all roles must be TenantAccountJoinRole')
|
||||
raise ValueError("all roles must be TenantAccountJoinRole")
|
||||
|
||||
return db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.role.in_([role.value for role in roles])
|
||||
).first() is not None
|
||||
return (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
|
||||
"""Get the role of the current account for a given tenant"""
|
||||
join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.account_id == account.id
|
||||
).first()
|
||||
join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.first()
|
||||
)
|
||||
return join.role if join else None
|
||||
|
||||
@staticmethod
|
||||
@ -420,29 +420,26 @@ class TenantService:
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
|
||||
"""Check member permission"""
|
||||
perms = {
|
||||
'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
'remove': [TenantAccountRole.OWNER],
|
||||
'update': [TenantAccountRole.OWNER]
|
||||
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
"remove": [TenantAccountRole.OWNER],
|
||||
"update": [TenantAccountRole.OWNER],
|
||||
}
|
||||
if action not in ['add', 'remove', 'update']:
|
||||
if action not in ["add", "remove", "update"]:
|
||||
raise InvalidActionError("Invalid action.")
|
||||
|
||||
if member:
|
||||
if operator.id == member.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
ta_operator = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
account_id=operator.id
|
||||
).first()
|
||||
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
|
||||
|
||||
if not ta_operator or ta_operator.role not in perms[action]:
|
||||
raise NoPermissionError(f'No permission to {action} member.')
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
"""Remove member from tenant"""
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
@ -455,23 +452,17 @@ class TenantService:
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, 'update')
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
|
||||
target_member_join = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
account_id=member.id
|
||||
).first()
|
||||
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
|
||||
|
||||
if target_member_join.role == new_role:
|
||||
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
|
||||
|
||||
if new_role == 'owner':
|
||||
if new_role == "owner":
|
||||
# Find the current owner and change their role to 'admin'
|
||||
current_owner_join = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
role='owner'
|
||||
).first()
|
||||
current_owner_join.role = 'admin'
|
||||
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join.role = "admin"
|
||||
|
||||
# Update the role of the target member
|
||||
target_member_join.role = new_role
|
||||
@ -480,8 +471,8 @@ class TenantService:
|
||||
@staticmethod
|
||||
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
|
||||
"""Dissolve tenant"""
|
||||
if not TenantService.check_member_permission(tenant, operator, operator, 'remove'):
|
||||
raise NoPermissionError('No permission to dissolve tenant.')
|
||||
if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
|
||||
raise NoPermissionError("No permission to dissolve tenant.")
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
||||
db.session.delete(tenant)
|
||||
db.session.commit()
|
||||
@ -494,10 +485,9 @@ class TenantService:
|
||||
|
||||
|
||||
class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def _get_invitation_token_key(cls, token: str) -> str:
|
||||
return f'member_invite:token:{token}'
|
||||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
|
||||
@ -523,9 +513,7 @@ class RegisterService:
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
dify_setup = DifySetup(
|
||||
version=dify_config.CURRENT_VERSION
|
||||
)
|
||||
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
@ -535,34 +523,35 @@ class RegisterService:
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.commit()
|
||||
|
||||
logging.exception(f'Setup failed: {e}')
|
||||
raise ValueError(f'Setup failed: {e}')
|
||||
logging.exception(f"Setup failed: {e}")
|
||||
raise ValueError(f"Setup failed: {e}")
|
||||
|
||||
@classmethod
|
||||
def register(cls, email, name,
|
||||
password: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
status: Optional[AccountStatus] = None) -> Account:
|
||||
def register(
|
||||
cls,
|
||||
email,
|
||||
name,
|
||||
password: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
status: Optional[AccountStatus] = None,
|
||||
) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
try:
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=language if language else languages[0],
|
||||
password=password
|
||||
email=email, name=name, interface_language=language if language else languages[0], password=password
|
||||
)
|
||||
account.status = AccountStatus.ACTIVE.value if not status else status.value
|
||||
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
if open_id is not None or provider is not None:
|
||||
AccountService.link_account_integrate(provider, open_id, account)
|
||||
if dify_config.EDITION != 'SELF_HOSTED':
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
@ -570,30 +559,29 @@ class RegisterService:
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logging.error(f'Register failed: {e}')
|
||||
raise AccountRegisterError(f'Registration failed: {e}') from e
|
||||
logging.error(f"Register failed: {e}")
|
||||
raise AccountRegisterError(f"Registration failed: {e}") from e
|
||||
|
||||
return account
|
||||
|
||||
@classmethod
|
||||
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
|
||||
def invite_new_member(
|
||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None
|
||||
) -> str:
|
||||
"""Invite new member"""
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, 'add')
|
||||
name = email.split('@')[0]
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
name = email.split("@")[0]
|
||||
|
||||
account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING)
|
||||
# Create new tenant member for invited tenant
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, 'add')
|
||||
ta = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id
|
||||
).first()
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
|
||||
if not ta:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
@ -609,7 +597,7 @@ class RegisterService:
|
||||
language=account.interface_language,
|
||||
to=email,
|
||||
token=token,
|
||||
inviter_name=inviter.name if inviter else 'Dify',
|
||||
inviter_name=inviter.name if inviter else "Dify",
|
||||
workspace_name=tenant.name,
|
||||
)
|
||||
|
||||
@ -619,23 +607,19 @@ class RegisterService:
|
||||
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
invitation_data = {
|
||||
'account_id': account.id,
|
||||
'email': account.email,
|
||||
'workspace_id': tenant.id,
|
||||
"account_id": account.id,
|
||||
"email": account.email,
|
||||
"workspace_id": tenant.id,
|
||||
}
|
||||
expiryHours = dify_config.INVITE_EXPIRY_HOURS
|
||||
redis_client.setex(
|
||||
cls._get_invitation_token_key(token),
|
||||
expiryHours * 60 * 60,
|
||||
json.dumps(invitation_data)
|
||||
)
|
||||
redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data))
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
||||
if workspace_id and email:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
|
||||
redis_client.delete(cache_key)
|
||||
else:
|
||||
redis_client.delete(cls._get_invitation_token_key(token))
|
||||
@ -646,17 +630,21 @@ class RegisterService:
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == invitation_data['workspace_id'],
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
tenant = (
|
||||
db.session.query(Tenant)
|
||||
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first()
|
||||
tenant_account = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
@ -665,29 +653,29 @@ class RegisterService:
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if invitation_data['account_id'] != str(account.id):
|
||||
if invitation_data["account_id"] != str(account.id):
|
||||
return None
|
||||
|
||||
return {
|
||||
'account': account,
|
||||
'data': invitation_data,
|
||||
'tenant': tenant,
|
||||
"account": account,
|
||||
"data": invitation_data,
|
||||
"tenant": tenant,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]:
|
||||
if workspace_id is not None and email is not None:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}'
|
||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||
account_id = redis_client.get(cache_key)
|
||||
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
return {
|
||||
'account_id': account_id.decode('utf-8'),
|
||||
'email': email,
|
||||
'workspace_id': workspace_id,
|
||||
"account_id": account_id.decode("utf-8"),
|
||||
"email": email,
|
||||
"workspace_id": workspace_id,
|
||||
}
|
||||
else:
|
||||
data = redis_client.get(cls._get_invitation_token_key(token))
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
import copy
|
||||
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
@ -17,59 +16,78 @@ from models.model import AppMode
|
||||
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
|
||||
@classmethod
|
||||
def get_prompt(cls, args: dict) -> dict:
|
||||
app_mode = args['app_mode']
|
||||
model_mode = args['model_mode']
|
||||
model_name = args['model_name']
|
||||
has_context = args['has_context']
|
||||
app_mode = args["app_mode"]
|
||||
model_mode = args["model_mode"]
|
||||
model_name = args["model_name"]
|
||||
has_context = args["has_context"]
|
||||
|
||||
if 'baichuan' in model_name.lower():
|
||||
if "baichuan" in model_name.lower():
|
||||
return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
|
||||
else:
|
||||
return cls.get_common_prompt(app_mode, model_mode, has_context)
|
||||
|
||||
@classmethod
|
||||
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
|
||||
|
||||
if has_context == "true":
|
||||
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
|
||||
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
)
|
||||
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
|
||||
|
||||
if has_context == "true":
|
||||
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
|
||||
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
)
|
||||
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
|
||||
has_context,
|
||||
baichuan_context_prompt,
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
|
||||
@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
def get_agent_logs(cls, app_model: App,
|
||||
conversation_id: str,
|
||||
message_id: str) -> dict:
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
conversation: Conversation = db.session.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
).first()
|
||||
conversation: Conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation not found: {conversation_id}")
|
||||
|
||||
message: Message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.conversation_id == conversation_id,
|
||||
).first()
|
||||
message: Message = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.id == message_id,
|
||||
Message.conversation_id == conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {message_id}")
|
||||
|
||||
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
|
||||
if conversation.from_end_user_id:
|
||||
# only select name field
|
||||
executor = db.session.query(EndUser, EndUser.name).filter(
|
||||
EndUser.id == conversation.from_end_user_id
|
||||
).first()
|
||||
executor = (
|
||||
db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
|
||||
)
|
||||
else:
|
||||
executor = db.session.query(Account, Account.name).filter(
|
||||
Account.id == conversation.from_account_id
|
||||
).first()
|
||||
|
||||
executor = (
|
||||
db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
|
||||
)
|
||||
|
||||
if executor:
|
||||
executor = executor.name
|
||||
else:
|
||||
executor = 'Unknown'
|
||||
executor = "Unknown"
|
||||
|
||||
timezone = pytz.timezone(current_user.timezone)
|
||||
|
||||
result = {
|
||||
'meta': {
|
||||
'status': 'success',
|
||||
'executor': executor,
|
||||
'start_time': message.created_at.astimezone(timezone).isoformat(),
|
||||
'elapsed_time': message.provider_response_latency,
|
||||
'total_tokens': message.answer_tokens + message.message_tokens,
|
||||
'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'),
|
||||
'iterations': len(agent_thoughts),
|
||||
"meta": {
|
||||
"status": "success",
|
||||
"executor": executor,
|
||||
"start_time": message.created_at.astimezone(timezone).isoformat(),
|
||||
"elapsed_time": message.provider_response_latency,
|
||||
"total_tokens": message.answer_tokens + message.message_tokens,
|
||||
"agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"),
|
||||
"iterations": len(agent_thoughts),
|
||||
},
|
||||
'iterations': [],
|
||||
'files': message.files,
|
||||
"iterations": [],
|
||||
"files": message.files,
|
||||
}
|
||||
|
||||
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
|
||||
@ -86,12 +92,12 @@ class AgentService:
|
||||
tool_input = tool_inputs.get(tool_name, {})
|
||||
tool_output = tool_outputs.get(tool_name, {})
|
||||
tool_meta_data = tool_meta.get(tool_name, {})
|
||||
tool_config = tool_meta_data.get('tool_config', {})
|
||||
if tool_config.get('tool_provider_type', '') != 'dataset-retrieval':
|
||||
tool_config = tool_meta_data.get("tool_config", {})
|
||||
if tool_config.get("tool_provider_type", "") != "dataset-retrieval":
|
||||
tool_icon = ToolManager.get_tool_icon(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider_type=tool_config.get('tool_provider_type', ''),
|
||||
provider_id=tool_config.get('tool_provider', ''),
|
||||
provider_type=tool_config.get("tool_provider_type", ""),
|
||||
provider_id=tool_config.get("tool_provider", ""),
|
||||
)
|
||||
if not tool_icon:
|
||||
tool_entity = find_agent_tool(tool_name)
|
||||
@ -102,30 +108,34 @@ class AgentService:
|
||||
provider_id=tool_entity.provider_id,
|
||||
)
|
||||
else:
|
||||
tool_icon = ''
|
||||
tool_icon = ""
|
||||
|
||||
tool_calls.append({
|
||||
'status': 'success' if not tool_meta_data.get('error') else 'error',
|
||||
'error': tool_meta_data.get('error'),
|
||||
'time_cost': tool_meta_data.get('time_cost', 0),
|
||||
'tool_name': tool_name,
|
||||
'tool_label': tool_label,
|
||||
'tool_input': tool_input,
|
||||
'tool_output': tool_output,
|
||||
'tool_parameters': tool_meta_data.get('tool_parameters', {}),
|
||||
'tool_icon': tool_icon,
|
||||
})
|
||||
tool_calls.append(
|
||||
{
|
||||
"status": "success" if not tool_meta_data.get("error") else "error",
|
||||
"error": tool_meta_data.get("error"),
|
||||
"time_cost": tool_meta_data.get("time_cost", 0),
|
||||
"tool_name": tool_name,
|
||||
"tool_label": tool_label,
|
||||
"tool_input": tool_input,
|
||||
"tool_output": tool_output,
|
||||
"tool_parameters": tool_meta_data.get("tool_parameters", {}),
|
||||
"tool_icon": tool_icon,
|
||||
}
|
||||
)
|
||||
|
||||
result['iterations'].append({
|
||||
'tokens': agent_thought.tokens,
|
||||
'tool_calls': tool_calls,
|
||||
'tool_raw': {
|
||||
'inputs': agent_thought.tool_input,
|
||||
'outputs': agent_thought.observation,
|
||||
},
|
||||
'thought': agent_thought.thought,
|
||||
'created_at': agent_thought.created_at.isoformat(),
|
||||
'files': agent_thought.files,
|
||||
})
|
||||
result["iterations"].append(
|
||||
{
|
||||
"tokens": agent_thought.tokens,
|
||||
"tool_calls": tool_calls,
|
||||
"tool_raw": {
|
||||
"inputs": agent_thought.tool_input,
|
||||
"outputs": agent_thought.observation,
|
||||
},
|
||||
"thought": agent_thought.thought,
|
||||
"created_at": agent_thought.created_at.isoformat(),
|
||||
"files": agent_thought.files,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@ -23,21 +23,18 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
if args.get('message_id'):
|
||||
message_id = str(args['message_id'])
|
||||
if args.get("message_id"):
|
||||
message_id = str(args["message_id"])
|
||||
# get message info
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app.id
|
||||
).first()
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@ -45,159 +42,166 @@ class AppAnnotationService:
|
||||
annotation = message.annotation
|
||||
# save the message annotation
|
||||
if annotation:
|
||||
annotation.content = args['answer']
|
||||
annotation.question = args['question']
|
||||
annotation.content = args["answer"]
|
||||
annotation.question = args["question"]
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
content=args["answer"],
|
||||
question=args["question"],
|
||||
account_id=current_user.id,
|
||||
)
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
|
||||
app_id, annotation_setting.collection_binding_id)
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
|
||||
enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
return {
|
||||
'job_id': cache_result,
|
||||
'job_status': 'processing'
|
||||
}
|
||||
return {"job_id": cache_result, "job_status": "processing"}
|
||||
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
|
||||
enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(enable_app_annotation_job_key, 'waiting')
|
||||
enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
|
||||
args['score_threshold'],
|
||||
args['embedding_provider_name'], args['embedding_model_name'])
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
||||
enable_annotation_reply_task.delay(
|
||||
str(job_id),
|
||||
app_id,
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
args["score_threshold"],
|
||||
args["embedding_provider_name"],
|
||||
args["embedding_model_name"],
|
||||
)
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str) -> dict:
|
||||
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
|
||||
disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
return {
|
||||
'job_id': cache_result,
|
||||
'job_status': 'processing'
|
||||
}
|
||||
return {"job_id": cache_result, "job_status": "processing"}
|
||||
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
|
||||
disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(disable_app_annotation_job_key, 'waiting')
|
||||
redis_client.setnx(disable_app_annotation_job_key, "waiting")
|
||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
if keyword:
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.filter(
|
||||
or_(
|
||||
MessageAnnotation.question.ilike('%{}%'.format(keyword)),
|
||||
MessageAnnotation.content.ilike('%{}%'.format(keyword))
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.filter(
|
||||
or_(
|
||||
MessageAnnotation.question.ilike("%{}%".format(keyword)),
|
||||
MessageAnnotation.content.ilike("%{}%".format(keyword)),
|
||||
)
|
||||
)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
else:
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
)
|
||||
return annotations.items, annotations.total
|
||||
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc()).all())
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return annotations
|
||||
|
||||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
|
||||
app_id, annotation_setting.collection_binding_id)
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -207,30 +211,34 @@ class AppAnnotationService:
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
annotation.content = args['answer']
|
||||
annotation.question = args['question']
|
||||
annotation.content = args["answer"]
|
||||
annotation.question = args["question"]
|
||||
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
|
||||
if app_annotation_setting:
|
||||
update_annotation_to_index_task.delay(annotation.id, annotation.question,
|
||||
current_user.current_tenant_id,
|
||||
app_id, app_annotation_setting.collection_binding_id)
|
||||
update_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
annotation.question,
|
||||
current_user.current_tenant_id,
|
||||
app_id,
|
||||
app_annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -242,33 +250,34 @@ class AppAnnotationService:
|
||||
|
||||
db.session.delete(annotation)
|
||||
|
||||
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
.all()
|
||||
)
|
||||
annotation_hit_histories = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
.all()
|
||||
)
|
||||
if annotation_hit_histories:
|
||||
for annotation_hit_history in annotation_hit_histories:
|
||||
db.session.delete(annotation_hit_history)
|
||||
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , delete annotation index
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(annotation.id, app_id,
|
||||
current_user.current_tenant_id,
|
||||
app_annotation_setting.collection_binding_id)
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -278,10 +287,7 @@ class AppAnnotationService:
|
||||
df = pd.read_csv(file)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
content = {
|
||||
'question': row[0],
|
||||
'answer': row[1]
|
||||
}
|
||||
content = {"question": row[0], "answer": row[1]}
|
||||
result.append(content)
|
||||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
@ -293,28 +299,24 @@ class AppAnnotationService:
|
||||
raise ValueError("The number of annotations exceeds the limit of your subscription.")
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
||||
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, 'waiting')
|
||||
batch_import_annotations_task.delay(str(job_id), result, app_id,
|
||||
current_user.current_tenant_id, current_user.id)
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_import_annotations_task.delay(
|
||||
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
|
||||
)
|
||||
except Exception as e:
|
||||
return {
|
||||
'error_msg': str(e)
|
||||
}
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
return {"error_msg": str(e)}
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -324,12 +326,15 @@ class AppAnnotationService:
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.app_id == app_id,
|
||||
AppAnnotationHitHistory.annotation_id == annotation_id,
|
||||
)
|
||||
.order_by(AppAnnotationHitHistory.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
annotation_hit_histories = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.filter(
|
||||
AppAnnotationHitHistory.app_id == app_id,
|
||||
AppAnnotationHitHistory.annotation_id == annotation_id,
|
||||
)
|
||||
.order_by(AppAnnotationHitHistory.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
)
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total
|
||||
|
||||
@classmethod
|
||||
@ -341,15 +346,21 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
|
||||
annotation_content: str, query: str, user_id: str,
|
||||
message_id: str, from_source: str, score: float):
|
||||
def add_annotation_history(
|
||||
cls,
|
||||
annotation_id: str,
|
||||
app_id: str,
|
||||
annotation_question: str,
|
||||
annotation_content: str,
|
||||
query: str,
|
||||
user_id: str,
|
||||
message_id: str,
|
||||
from_source: str,
|
||||
score: float,
|
||||
):
|
||||
# add hit count to annotation
|
||||
db.session.query(MessageAnnotation).filter(
|
||||
MessageAnnotation.id == annotation_id
|
||||
).update(
|
||||
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
|
||||
synchronize_session=False
|
||||
db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update(
|
||||
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
|
||||
annotation_hit_history = AppAnnotationHitHistory(
|
||||
@ -361,7 +372,7 @@ class AppAnnotationService:
|
||||
score=score,
|
||||
message_id=message_id,
|
||||
annotation_question=annotation_question,
|
||||
annotation_content=annotation_content
|
||||
annotation_content=annotation_content,
|
||||
)
|
||||
db.session.add(annotation_hit_history)
|
||||
db.session.commit()
|
||||
@ -369,17 +380,18 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
return {
|
||||
@ -388,32 +400,34 @@ class AppAnnotationService:
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name
|
||||
}
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
return {
|
||||
"enabled": False
|
||||
}
|
||||
return {"enabled": False}
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id,
|
||||
AppAnnotationSetting.id == annotation_setting_id,
|
||||
).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting)
|
||||
.filter(
|
||||
AppAnnotationSetting.app_id == app_id,
|
||||
AppAnnotationSetting.id == annotation_setting_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not annotation_setting:
|
||||
raise NotFound("App annotation not found")
|
||||
annotation_setting.score_threshold = args['score_threshold']
|
||||
annotation_setting.score_threshold = args["score_threshold"]
|
||||
annotation_setting.updated_user_id = current_user.id
|
||||
annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(annotation_setting)
|
||||
@ -427,6 +441,6 @@ class AppAnnotationService:
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name
|
||||
}
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
|
||||
@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
|
||||
|
||||
class APIBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
|
||||
extension_list = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=tenant_id) \
|
||||
.order_by(APIBasedExtension.created_at.desc()) \
|
||||
.all()
|
||||
extension_list = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.order_by(APIBasedExtension.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
for extension in extension_list:
|
||||
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
|
||||
@ -35,10 +36,12 @@ class APIBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
extension = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=tenant_id) \
|
||||
.filter_by(id=api_based_extension_id) \
|
||||
extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.filter_by(id=api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not extension:
|
||||
raise ValueError("API based extension is not found")
|
||||
@ -55,20 +58,24 @@ class APIBasedExtensionService:
|
||||
|
||||
if not extension_data.id:
|
||||
# case one: check new data, name must be unique
|
||||
is_name_existed = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=extension_data.tenant_id) \
|
||||
.filter_by(name=extension_data.name) \
|
||||
is_name_existed = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=extension_data.tenant_id)
|
||||
.filter_by(name=extension_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if is_name_existed:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
else:
|
||||
# case two: check existing data, name must be unique
|
||||
is_name_existed = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=extension_data.tenant_id) \
|
||||
.filter_by(name=extension_data.name) \
|
||||
.filter(APIBasedExtension.id != extension_data.id) \
|
||||
is_name_existed = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=extension_data.tenant_id)
|
||||
.filter_by(name=extension_data.name)
|
||||
.filter(APIBasedExtension.id != extension_data.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if is_name_existed:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
@ -92,7 +99,7 @@ class APIBasedExtensionService:
|
||||
try:
|
||||
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
|
||||
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
|
||||
if resp.get('result') != 'pong':
|
||||
if resp.get("result") != "pong":
|
||||
raise ValueError(resp)
|
||||
except Exception as e:
|
||||
raise ValueError("connection error: {}".format(e))
|
||||
|
||||
@ -13,9 +13,9 @@ from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
current_dsl_version = "0.1.0"
|
||||
current_dsl_version = "0.1.1"
|
||||
dsl_to_dify_version_mapping: dict[str, str] = {
|
||||
"0.1.0": "0.6.0", # dsl version -> from dify version
|
||||
"0.1.1": "0.6.0", # dsl version -> from dify version
|
||||
}
|
||||
|
||||
|
||||
@ -75,40 +75,44 @@ class AppDslService:
|
||||
# check or repair dsl version
|
||||
import_data = cls._check_or_fix_dsl(import_data)
|
||||
|
||||
app_data = import_data.get('app')
|
||||
app_data = import_data.get("app")
|
||||
if not app_data:
|
||||
raise ValueError("Missing app in data argument")
|
||||
|
||||
# get app basic info
|
||||
name = args.get("name") if args.get("name") else app_data.get('name')
|
||||
description = args.get("description") if args.get("description") else app_data.get('description', '')
|
||||
icon = args.get("icon") if args.get("icon") else app_data.get('icon')
|
||||
icon_background = args.get("icon_background") if args.get("icon_background") \
|
||||
else app_data.get('icon_background')
|
||||
name = args.get("name") if args.get("name") else app_data.get("name")
|
||||
description = args.get("description") if args.get("description") else app_data.get("description", "")
|
||||
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type")
|
||||
icon = args.get("icon") if args.get("icon") else app_data.get("icon")
|
||||
icon_background = (
|
||||
args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
|
||||
)
|
||||
|
||||
# import dsl and create app
|
||||
app_mode = AppMode.value_of(app_data.get('mode'))
|
||||
app_mode = AppMode.value_of(app_data.get("mode"))
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
app = cls._import_and_create_new_workflow_based_app(
|
||||
tenant_id=tenant_id,
|
||||
app_mode=app_mode,
|
||||
workflow_data=import_data.get('workflow'),
|
||||
workflow_data=import_data.get("workflow"),
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]:
|
||||
app = cls._import_and_create_new_model_config_based_app(
|
||||
tenant_id=tenant_id,
|
||||
app_mode=app_mode,
|
||||
model_config_data=import_data.get('model_config'),
|
||||
model_config_data=import_data.get("model_config"),
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
@ -131,27 +135,26 @@ class AppDslService:
|
||||
# check or repair dsl version
|
||||
import_data = cls._check_or_fix_dsl(import_data)
|
||||
|
||||
app_data = import_data.get('app')
|
||||
app_data = import_data.get("app")
|
||||
if not app_data:
|
||||
raise ValueError("Missing app in data argument")
|
||||
|
||||
# import dsl and overwrite app
|
||||
app_mode = AppMode.value_of(app_data.get('mode'))
|
||||
app_mode = AppMode.value_of(app_data.get("mode"))
|
||||
if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
raise ValueError("Only support import workflow in advanced-chat or workflow app.")
|
||||
|
||||
if app_data.get('mode') != app_model.mode:
|
||||
raise ValueError(
|
||||
f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
|
||||
if app_data.get("mode") != app_model.mode:
|
||||
raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
|
||||
|
||||
return cls._import_and_overwrite_workflow_based_app(
|
||||
app_model=app_model,
|
||||
workflow_data=import_data.get('workflow'),
|
||||
workflow_data=import_data.get("workflow"),
|
||||
account=account,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def export_dsl(cls, app_model: App, include_secret:bool = False) -> str:
|
||||
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
|
||||
"""
|
||||
Export app
|
||||
:param app_model: App instance
|
||||
@ -165,18 +168,20 @@ class AppDslService:
|
||||
"app": {
|
||||
"name": app_model.name,
|
||||
"mode": app_model.mode,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"description": app_model.description
|
||||
}
|
||||
"icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
|
||||
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
|
||||
"description": app_model.description,
|
||||
},
|
||||
}
|
||||
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret)
|
||||
cls._append_workflow_export_data(
|
||||
export_data=export_data, app_model=app_model, include_secret=include_secret
|
||||
)
|
||||
else:
|
||||
cls._append_model_config_export_data(export_data, app_model)
|
||||
|
||||
return yaml.dump(export_data)
|
||||
return yaml.dump(export_data, allow_unicode=True)
|
||||
|
||||
@classmethod
|
||||
def _check_or_fix_dsl(cls, import_data: dict) -> dict:
|
||||
@ -185,30 +190,35 @@ class AppDslService:
|
||||
|
||||
:param import_data: import data
|
||||
"""
|
||||
if not import_data.get('version'):
|
||||
import_data['version'] = "0.1.0"
|
||||
if not import_data.get("version"):
|
||||
import_data["version"] = "0.1.0"
|
||||
|
||||
if not import_data.get('kind') or import_data.get('kind') != "app":
|
||||
import_data['kind'] = "app"
|
||||
if not import_data.get("kind") or import_data.get("kind") != "app":
|
||||
import_data["kind"] = "app"
|
||||
|
||||
if import_data.get('version') != current_dsl_version:
|
||||
if import_data.get("version") != current_dsl_version:
|
||||
# Currently only one DSL version, so no difference checks or compatibility fixes will be performed.
|
||||
logger.warning(f"DSL version {import_data.get('version')} is not compatible "
|
||||
f"with current version {current_dsl_version}, related to "
|
||||
f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.")
|
||||
logger.warning(
|
||||
f"DSL version {import_data.get('version')} is not compatible "
|
||||
f"with current version {current_dsl_version}, related to "
|
||||
f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}."
|
||||
)
|
||||
|
||||
return import_data
|
||||
|
||||
@classmethod
|
||||
def _import_and_create_new_workflow_based_app(cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
workflow_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def _import_and_create_new_workflow_based_app(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
workflow_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str,
|
||||
) -> App:
|
||||
"""
|
||||
Import app dsl and create new workflow based app
|
||||
|
||||
@ -218,12 +228,12 @@ class AppDslService:
|
||||
:param account: Account instance
|
||||
:param name: app name
|
||||
:param description: app description
|
||||
:param icon_type: app icon type, "emoji" or "image"
|
||||
:param icon: app icon
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
if not workflow_data:
|
||||
raise ValueError("Missing workflow in data argument "
|
||||
"when app mode is advanced-chat or workflow")
|
||||
raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
|
||||
|
||||
app = cls._create_app(
|
||||
tenant_id=tenant_id,
|
||||
@ -231,35 +241,34 @@ class AppDslService:
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
|
||||
# init draft workflow
|
||||
environment_variables_list = workflow_data.get('environment_variables') or []
|
||||
environment_variables_list = workflow_data.get("environment_variables") or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
conversation_variables_list = workflow_data.get("conversation_variables") or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow_data.get('graph', {}),
|
||||
features=workflow_data.get('../core/app/features', {}),
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("../core/app/features", {}),
|
||||
unique_hash=None,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
workflow_service.publish_workflow(
|
||||
app_model=app,
|
||||
account=account,
|
||||
draft_workflow=draft_workflow
|
||||
)
|
||||
workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow)
|
||||
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def _import_and_overwrite_workflow_based_app(cls,
|
||||
app_model: App,
|
||||
workflow_data: dict,
|
||||
account: Account) -> Workflow:
|
||||
def _import_and_overwrite_workflow_based_app(
|
||||
cls, app_model: App, workflow_data: dict, account: Account
|
||||
) -> Workflow:
|
||||
"""
|
||||
Import app dsl and overwrite workflow based app
|
||||
|
||||
@ -268,8 +277,7 @@ class AppDslService:
|
||||
:param account: Account instance
|
||||
"""
|
||||
if not workflow_data:
|
||||
raise ValueError("Missing workflow in data argument "
|
||||
"when app mode is advanced-chat or workflow")
|
||||
raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
@ -280,29 +288,35 @@ class AppDslService:
|
||||
unique_hash = None
|
||||
|
||||
# sync draft workflow
|
||||
environment_variables_list = workflow_data.get('environment_variables') or []
|
||||
environment_variables_list = workflow_data.get("environment_variables") or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
conversation_variables_list = workflow_data.get("conversation_variables") or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
draft_workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=workflow_data.get('graph', {}),
|
||||
features=workflow_data.get('features', {}),
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return draft_workflow
|
||||
|
||||
@classmethod
|
||||
def _import_and_create_new_model_config_based_app(cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
model_config_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def _import_and_create_new_model_config_based_app(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
model_config_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str,
|
||||
) -> App:
|
||||
"""
|
||||
Import app dsl and create new model config based app
|
||||
|
||||
@ -316,8 +330,7 @@ class AppDslService:
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
if not model_config_data:
|
||||
raise ValueError("Missing model_config in data argument "
|
||||
"when app mode is chat, agent-chat or completion")
|
||||
raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion")
|
||||
|
||||
app = cls._create_app(
|
||||
tenant_id=tenant_id,
|
||||
@ -325,35 +338,38 @@ class AppDslService:
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(model_config_data)
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
app_model_config_was_updated.send(
|
||||
app,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_model_config_was_updated.send(app, app_model_config=app_model_config)
|
||||
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def _create_app(cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def _create_app(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str,
|
||||
) -> App:
|
||||
"""
|
||||
Create new app
|
||||
|
||||
@ -362,6 +378,7 @@ class AppDslService:
|
||||
:param account: Account instance
|
||||
:param name: app name
|
||||
:param description: app description
|
||||
:param icon_type: app icon type, "emoji" or "image"
|
||||
:param icon: app icon
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
@ -370,10 +387,13 @@ class AppDslService:
|
||||
mode=app_mode.value,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
enable_site=True,
|
||||
enable_api=True
|
||||
enable_api=True,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
|
||||
db.session.add(app)
|
||||
@ -395,7 +415,7 @@ class AppDslService:
|
||||
if not workflow:
|
||||
raise ValueError("Missing draft workflow configuration, please check.")
|
||||
|
||||
export_data['workflow'] = workflow.to_dict(include_secret=include_secret)
|
||||
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
|
||||
@ -408,4 +428,4 @@ class AppDslService:
|
||||
if not app_model_config:
|
||||
raise ValueError("Missing app configuration, please check.")
|
||||
|
||||
export_data['model_config'] = app_model_config.to_dict()
|
||||
export_data["model_config"] = app_model_config.to_dict()
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from openai._exceptions import RateLimitError
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
@ -10,18 +12,20 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting import RateLimit
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate(cls, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
def generate(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
"""
|
||||
App Content Generate
|
||||
:param app_model: app model
|
||||
@ -37,51 +41,56 @@ class AppGenerateService:
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
return rate_limit.generate(CompletionAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
return rate_limit.generate(AgentChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
AgentChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
return rate_limit.generate(ChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
ChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return rate_limit.generate(AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming,
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return rate_limit.generate(WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming,
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid app mode {app_model.mode}')
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
except RateLimitError as e:
|
||||
raise InvokeRateLimitError(str(e))
|
||||
finally:
|
||||
if not streaming:
|
||||
rate_limit.exit(request_id)
|
||||
@ -94,38 +103,31 @@ class AppGenerateService:
|
||||
return max_active_requests
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(cls, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
node_id: str,
|
||||
args: Any,
|
||||
streaming: bool = True):
|
||||
def generate_single_iteration(
|
||||
cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True
|
||||
):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
stream=streaming
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
stream=streaming
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid app mode {app_model.mode}')
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
def generate_more_like_this(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
message_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[dict, Generator]:
|
||||
"""
|
||||
Generate more like this
|
||||
:param app_model: app model
|
||||
@ -136,11 +138,7 @@ class AppGenerateService:
|
||||
:return:
|
||||
"""
|
||||
return CompletionAppGenerator().generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -157,12 +155,12 @@ class AppGenerateService:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
raise ValueError("Workflow not initialized")
|
||||
else:
|
||||
# fetch published workflow by app_model
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not published')
|
||||
raise ValueError("Workflow not published")
|
||||
|
||||
return workflow
|
||||
|
||||
@ -5,7 +5,6 @@ from models.model import AppMode
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
|
||||
if app_mode == AppMode.CHAT:
|
||||
|
||||
@ -33,27 +33,22 @@ class AppService:
|
||||
:param args: request args
|
||||
:return:
|
||||
"""
|
||||
filters = [
|
||||
App.tenant_id == tenant_id,
|
||||
App.is_universal == False
|
||||
]
|
||||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
|
||||
if args['mode'] == 'workflow':
|
||||
if args["mode"] == "workflow":
|
||||
filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value]))
|
||||
elif args['mode'] == 'chat':
|
||||
elif args["mode"] == "chat":
|
||||
filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value]))
|
||||
elif args['mode'] == 'agent-chat':
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT.value)
|
||||
elif args['mode'] == 'channel':
|
||||
elif args["mode"] == "channel":
|
||||
filters.append(App.mode == AppMode.CHANNEL.value)
|
||||
|
||||
if args.get('name'):
|
||||
name = args['name'][:30]
|
||||
filters.append(App.name.ilike(f'%{name}%'))
|
||||
if args.get('tag_ids'):
|
||||
target_ids = TagService.get_target_ids_by_tag_ids('app',
|
||||
tenant_id,
|
||||
args['tag_ids'])
|
||||
if args.get("name"):
|
||||
name = args["name"][:30]
|
||||
filters.append(App.name.ilike(f"%{name}%"))
|
||||
if args.get("tag_ids"):
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
|
||||
if target_ids:
|
||||
filters.append(App.id.in_(target_ids))
|
||||
else:
|
||||
@ -61,9 +56,9 @@ class AppService:
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
page=args["page"],
|
||||
per_page=args["limit"],
|
||||
error_out=False,
|
||||
)
|
||||
|
||||
return app_models
|
||||
@ -75,21 +70,20 @@ class AppService:
|
||||
:param args: request args
|
||||
:param account: Account instance
|
||||
"""
|
||||
app_mode = AppMode.value_of(args['mode'])
|
||||
app_mode = AppMode.value_of(args["mode"])
|
||||
app_template = default_app_templates[app_mode]
|
||||
|
||||
# get model config
|
||||
default_model_config = app_template.get('model_config')
|
||||
default_model_config = app_template.get("model_config")
|
||||
default_model_config = default_model_config.copy() if default_model_config else None
|
||||
if default_model_config and 'model' in default_model_config:
|
||||
if default_model_config and "model" in default_model_config:
|
||||
# get model provider
|
||||
model_manager = ModelManager()
|
||||
|
||||
# get default model instance
|
||||
try:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=account.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
model_instance = None
|
||||
@ -98,32 +92,43 @@ class AppService:
|
||||
model_instance = None
|
||||
|
||||
if model_instance:
|
||||
if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']:
|
||||
default_model_dict = default_model_config['model']
|
||||
if (
|
||||
model_instance.model == default_model_config["model"]["name"]
|
||||
and model_instance.provider == default_model_config["model"]["provider"]
|
||||
):
|
||||
default_model_dict = default_model_config["model"]
|
||||
else:
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
default_model_dict = {
|
||||
'provider': model_instance.provider,
|
||||
'name': model_instance.model,
|
||||
'mode': model_schema.model_properties.get(ModelPropertyKey.MODE),
|
||||
'completion_params': {}
|
||||
"provider": model_instance.provider,
|
||||
"name": model_instance.model,
|
||||
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
|
||||
"completion_params": {},
|
||||
}
|
||||
else:
|
||||
default_model_dict = default_model_config['model']
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
default_model_config["model"]["provider"] = provider
|
||||
default_model_config["model"]["name"] = model
|
||||
default_model_dict = default_model_config["model"]
|
||||
|
||||
default_model_config['model'] = json.dumps(default_model_dict)
|
||||
default_model_config["model"] = json.dumps(default_model_dict)
|
||||
|
||||
app = App(**app_template['app'])
|
||||
app.name = args['name']
|
||||
app.description = args.get('description', '')
|
||||
app.mode = args['mode']
|
||||
app.icon = args['icon']
|
||||
app.icon_background = args['icon_background']
|
||||
app = App(**app_template["app"])
|
||||
app.name = args["name"]
|
||||
app.description = args.get("description", "")
|
||||
app.mode = args["mode"]
|
||||
app.icon_type = args.get("icon_type", "emoji")
|
||||
app.icon = args["icon"]
|
||||
app.icon_background = args["icon_background"]
|
||||
app.tenant_id = tenant_id
|
||||
app.api_rph = args.get('api_rph', 0)
|
||||
app.api_rpm = args.get('api_rpm', 0)
|
||||
app.api_rph = args.get("api_rph", 0)
|
||||
app.api_rpm = args.get("api_rpm", 0)
|
||||
app.created_by = account.id
|
||||
app.updated_by = account.id
|
||||
|
||||
db.session.add(app)
|
||||
db.session.flush()
|
||||
@ -131,6 +136,8 @@ class AppService:
|
||||
if default_model_config:
|
||||
app_model_config = AppModelConfig(**default_model_config)
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
db.session.add(app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
@ -151,7 +158,7 @@ class AppService:
|
||||
model_config: AppModelConfig = app.app_model_config
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
@ -167,7 +174,7 @@ class AppService:
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f'AGENT.{app.id}'
|
||||
identity_id=f"AGENT.{app.id}",
|
||||
)
|
||||
|
||||
# get decrypted parameters
|
||||
@ -178,7 +185,7 @@ class AppService:
|
||||
masked_parameter = {}
|
||||
|
||||
# override tool parameters
|
||||
tool['tool_parameters'] = masked_parameter
|
||||
tool["tool_parameters"] = masked_parameter
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@ -189,13 +196,14 @@ class AppService:
|
||||
"""
|
||||
Modified App class
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.__dict__.update(app.__dict__)
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
return model_config
|
||||
|
||||
|
||||
app = ModifiedApp(app)
|
||||
|
||||
return app
|
||||
@ -207,11 +215,13 @@ class AppService:
|
||||
:param args: request args
|
||||
:return: App instance
|
||||
"""
|
||||
app.name = args.get('name')
|
||||
app.description = args.get('description', '')
|
||||
app.max_active_requests = args.get('max_active_requests')
|
||||
app.icon = args.get('icon')
|
||||
app.icon_background = args.get('icon_background')
|
||||
app.name = args.get("name")
|
||||
app.description = args.get("description", "")
|
||||
app.max_active_requests = args.get("max_active_requests")
|
||||
app.icon_type = args.get("icon_type", "emoji")
|
||||
app.icon = args.get("icon")
|
||||
app.icon_background = args.get("icon_background")
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
@ -228,6 +238,7 @@ class AppService:
|
||||
:return: App instance
|
||||
"""
|
||||
app.name = name
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
@ -243,6 +254,7 @@ class AppService:
|
||||
"""
|
||||
app.icon = icon
|
||||
app.icon_background = icon_background
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
@ -259,6 +271,7 @@ class AppService:
|
||||
return app
|
||||
|
||||
app.enable_site = enable_site
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
@ -275,6 +288,7 @@ class AppService:
|
||||
return app
|
||||
|
||||
app.enable_api = enable_api
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
@ -289,10 +303,7 @@ class AppService:
|
||||
db.session.commit()
|
||||
|
||||
# Trigger asynchronous deletion of app and related data
|
||||
remove_app_and_related_data_task.delay(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id
|
||||
)
|
||||
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
|
||||
|
||||
def get_app_meta(self, app_model: App) -> dict:
|
||||
"""
|
||||
@ -302,9 +313,7 @@ class AppService:
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
meta = {
|
||||
'tool_icons': {}
|
||||
}
|
||||
meta = {"tool_icons": {}}
|
||||
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
workflow = app_model.workflow
|
||||
@ -312,17 +321,19 @@ class AppService:
|
||||
return meta
|
||||
|
||||
graph = workflow.graph_dict
|
||||
nodes = graph.get('nodes', [])
|
||||
nodes = graph.get("nodes", [])
|
||||
tools = []
|
||||
for node in nodes:
|
||||
if node.get('data', {}).get('type') == 'tool':
|
||||
node_data = node.get('data', {})
|
||||
tools.append({
|
||||
'provider_type': node_data.get('provider_type'),
|
||||
'provider_id': node_data.get('provider_id'),
|
||||
'tool_name': node_data.get('tool_name'),
|
||||
'tool_parameters': {}
|
||||
})
|
||||
if node.get("data", {}).get("type") == "tool":
|
||||
node_data = node.get("data", {})
|
||||
tools.append(
|
||||
{
|
||||
"provider_type": node_data.get("provider_type"),
|
||||
"provider_id": node_data.get("provider_id"),
|
||||
"tool_name": node_data.get("tool_name"),
|
||||
"tool_parameters": {},
|
||||
}
|
||||
)
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
@ -332,30 +343,26 @@ class AppService:
|
||||
agent_config = app_model_config.agent_mode_dict or {}
|
||||
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
tools = agent_config.get("tools", [])
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
# current tool standard
|
||||
provider_type = tool.get('provider_type')
|
||||
provider_id = tool.get('provider_id')
|
||||
tool_name = tool.get('tool_name')
|
||||
if provider_type == 'builtin':
|
||||
meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
|
||||
elif provider_type == 'api':
|
||||
provider_type = tool.get("provider_type")
|
||||
provider_id = tool.get("provider_id")
|
||||
tool_name = tool.get("tool_name")
|
||||
if provider_type == "builtin":
|
||||
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
|
||||
elif provider_type == "api":
|
||||
try:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.id == provider_id
|
||||
).first()
|
||||
meta['tool_icons'][tool_name] = json.loads(provider.icon)
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
|
||||
)
|
||||
meta["tool_icons"][tool_name] = json.loads(provider.icon)
|
||||
except:
|
||||
meta['tool_icons'][tool_name] = {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
return meta
|
||||
|
||||
@ -17,7 +17,7 @@ from services.errors.audio import (
|
||||
|
||||
FILE_SIZE = 30
|
||||
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
|
||||
ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr']
|
||||
ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,19 +31,19 @@ class AudioService:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'):
|
||||
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
if not app_model_config.speech_to_text_dict["enabled"]:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
if file is None:
|
||||
raise NoAudioUploadedServiceError()
|
||||
|
||||
extension = file.mimetype
|
||||
if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]:
|
||||
if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]:
|
||||
raise UnsupportedAudioTypeServiceError()
|
||||
|
||||
file_content = file.read()
|
||||
@ -55,20 +55,25 @@ class AudioService:
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.SPEECH2TEXT
|
||||
tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
|
||||
)
|
||||
if model_instance is None:
|
||||
raise ProviderNotSupportSpeechToTextServiceError()
|
||||
|
||||
buffer = io.BytesIO(file_content)
|
||||
buffer.name = 'temp.mp3'
|
||||
buffer.name = "temp.mp3"
|
||||
|
||||
return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
|
||||
|
||||
@classmethod
|
||||
def transcript_tts(cls, app_model: App, text: Optional[str] = None,
|
||||
voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None):
|
||||
def transcript_tts(
|
||||
cls,
|
||||
app_model: App,
|
||||
text: Optional[str] = None,
|
||||
voice: Optional[str] = None,
|
||||
end_user: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
@ -84,65 +89,56 @@ class AudioService:
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'):
|
||||
if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = features_dict['text_to_speech'].get('voice') if voice is None else voice
|
||||
voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
|
||||
else:
|
||||
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
||||
|
||||
if not text_to_speech_dict.get('enabled'):
|
||||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = text_to_speech_dict.get('voice') if voice is None else voice
|
||||
voice = text_to_speech_dict.get("voice") if voice is None else voice
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
|
||||
)
|
||||
try:
|
||||
if not voice:
|
||||
voices = model_instance.get_tts_voices()
|
||||
if voices:
|
||||
voice = voices[0].get('value')
|
||||
voice = voices[0].get("value")
|
||||
else:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(),
|
||||
user=end_user,
|
||||
tenant_id=app_model.tenant_id,
|
||||
voice=voice
|
||||
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if message_id:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id
|
||||
).first()
|
||||
if message.answer == '' and message.status == 'normal':
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
if message.answer == "" and message.status == "normal":
|
||||
return None
|
||||
|
||||
else:
|
||||
response = invoke_tts(message.answer, app_model=app_model, voice=voice)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type='audio/mpeg')
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return response
|
||||
else:
|
||||
response = invoke_tts(text, app_model, voice)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type='audio/mpeg')
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def transcript_tts_voices(cls, tenant_id: str, language: str):
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
)
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
|
||||
if model_instance is None:
|
||||
raise ProviderNotSupportTextToSpeechServiceError()
|
||||
|
||||
|
||||
@ -1,14 +1,12 @@
|
||||
|
||||
from services.auth.firecrawl import FirecrawlAuth
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
if provider == 'firecrawl':
|
||||
if provider == "firecrawl":
|
||||
self.auth = FirecrawlAuth(credentials)
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
def validate_credentials(self):
|
||||
return self.auth.validate_credentials()
|
||||
|
||||
@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str) -> list:
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).all()
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
|
||||
.all()
|
||||
)
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
|
||||
args['credentials']['config']['api_key'] = api_key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args['category']
|
||||
data_source_api_key_binding.provider = args['provider']
|
||||
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
|
||||
data_source_api_key_binding.category = args["category"]
|
||||
data_source_api_key_binding.provider = args["provider"]
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).first()
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
@ -47,24 +51,24 @@ class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id
|
||||
).first()
|
||||
data_source_api_key_binding = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
|
||||
.first()
|
||||
)
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
if 'category' not in args or not args['category']:
|
||||
raise ValueError('category is required')
|
||||
if 'provider' not in args or not args['provider']:
|
||||
raise ValueError('provider is required')
|
||||
if 'credentials' not in args or not args['credentials']:
|
||||
raise ValueError('credentials is required')
|
||||
if not isinstance(args['credentials'], dict):
|
||||
raise ValueError('credentials must be a dictionary')
|
||||
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
|
||||
raise ValueError('auth_type is required')
|
||||
|
||||
if "category" not in args or not args["category"]:
|
||||
raise ValueError("category is required")
|
||||
if "provider" not in args or not args["provider"]:
|
||||
raise ValueError("provider is required")
|
||||
if "credentials" not in args or not args["credentials"]:
|
||||
raise ValueError("credentials is required")
|
||||
if not isinstance(args["credentials"], dict):
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get('auth_type')
|
||||
if auth_type != 'bearer':
|
||||
raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
|
||||
self.api_key = credentials.get('config').get('api_key', None)
|
||||
self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
|
||||
self.api_key = credentials.get("config").get("api_key", None)
|
||||
self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError('No API key provided')
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def validate_credentials(self):
|
||||
headers = self._prepare_headers()
|
||||
options = {
|
||||
'url': 'https://example.com',
|
||||
'crawlerOptions': {
|
||||
'excludes': [],
|
||||
'includes': [],
|
||||
'limit': 1
|
||||
},
|
||||
'pageOptions': {
|
||||
'onlyMainContent': True
|
||||
}
|
||||
"url": "https://example.com",
|
||||
"crawlerOptions": {"excludes": [], "includes": [], "limit": 1},
|
||||
"pageOptions": {"onlyMainContent": True},
|
||||
}
|
||||
response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
|
||||
response = self._post_request(f"{self.base_url}/v0/crawl", options, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self._handle_error(response)
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in [402, 409, 500]:
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
|
||||
raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
|
||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
||||
|
||||
@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
class BillingService:
|
||||
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
|
||||
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str):
|
||||
params = {'tenant_id': tenant_id}
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
billing_info = cls._send_request('GET', '/subscription/info', params=params)
|
||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||
|
||||
return billing_info
|
||||
|
||||
@classmethod
|
||||
def get_subscription(cls, plan: str,
|
||||
interval: str,
|
||||
prefilled_email: str = '',
|
||||
tenant_id: str = ''):
|
||||
params = {
|
||||
'plan': plan,
|
||||
'interval': interval,
|
||||
'prefilled_email': prefilled_email,
|
||||
'tenant_id': tenant_id
|
||||
}
|
||||
return cls._send_request('GET', '/subscription/payment-link', params=params)
|
||||
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
|
||||
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
|
||||
return cls._send_request("GET", "/subscription/payment-link", params=params)
|
||||
|
||||
@classmethod
|
||||
def get_model_provider_payment_link(cls,
|
||||
provider_name: str,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
prefilled_email: str):
|
||||
def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
|
||||
params = {
|
||||
'provider_name': provider_name,
|
||||
'tenant_id': tenant_id,
|
||||
'account_id': account_id,
|
||||
'prefilled_email': prefilled_email
|
||||
"provider_name": provider_name,
|
||||
"tenant_id": tenant_id,
|
||||
"account_id": account_id,
|
||||
"prefilled_email": prefilled_email,
|
||||
}
|
||||
return cls._send_request('GET', '/model-provider/payment-link', params=params)
|
||||
return cls._send_request("GET", "/model-provider/payment-link", params=params)
|
||||
|
||||
@classmethod
|
||||
def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''):
|
||||
params = {
|
||||
'prefilled_email': prefilled_email,
|
||||
'tenant_id': tenant_id
|
||||
}
|
||||
return cls._send_request('GET', '/invoices', params=params)
|
||||
def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
|
||||
params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
|
||||
return cls._send_request("GET", "/invoices", params=params)
|
||||
|
||||
@classmethod
|
||||
def _send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Billing-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
@ -69,10 +51,11 @@ class BillingService:
|
||||
def is_tenant_owner_or_admin(current_user):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == current_user.id
|
||||
).first()
|
||||
join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not TenantAccountRole.is_privileged_role(join.role):
|
||||
raise ValueError('Only team owner or team admin can perform this action')
|
||||
raise ValueError("Only team owner or team admin can perform this action")
|
||||
|
||||
@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
|
||||
class CodeBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_code_based_extension(module: str) -> list[dict]:
|
||||
module_extensions = code_based_extension.module_extensions(module)
|
||||
return [{
|
||||
'name': module_extension.name,
|
||||
'label': module_extension.label,
|
||||
'form_schema': module_extension.form_schema
|
||||
} for module_extension in module_extensions if not module_extension.builtin]
|
||||
return [
|
||||
{
|
||||
"name": module_extension.name,
|
||||
"label": module_extension.label,
|
||||
"form_schema": module_extension.form_schema,
|
||||
}
|
||||
for module_extension in module_extensions
|
||||
if not module_extension.builtin
|
||||
]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import asc, desc, or_
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
@ -14,21 +15,27 @@ from services.errors.message import MessageNotExistsError
|
||||
|
||||
class ConversationService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
include_ids: Optional[list] = None,
|
||||
exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
|
||||
def pagination_by_last_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
include_ids: Optional[list] = None,
|
||||
exclude_ids: Optional[list] = None,
|
||||
sort_by: str = "-updated_at",
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
base_query = db.session.query(Conversation).filter(
|
||||
Conversation.is_deleted == False,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value)
|
||||
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
|
||||
)
|
||||
|
||||
if include_ids is not None:
|
||||
@ -37,47 +44,67 @@ class ConversationService:
|
||||
if exclude_ids is not None:
|
||||
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
|
||||
|
||||
if last_id:
|
||||
last_conversation = base_query.filter(
|
||||
Conversation.id == last_id,
|
||||
).first()
|
||||
# define sort fields and directions
|
||||
sort_field, sort_direction = cls._get_sort_params(sort_by)
|
||||
|
||||
if last_id:
|
||||
last_conversation = base_query.filter(Conversation.id == last_id).first()
|
||||
if not last_conversation:
|
||||
raise LastConversationNotExistsError()
|
||||
|
||||
conversations = base_query.filter(
|
||||
Conversation.created_at < last_conversation.created_at,
|
||||
Conversation.id != last_conversation.id
|
||||
).order_by(Conversation.created_at.desc()).limit(limit).all()
|
||||
else:
|
||||
conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all()
|
||||
# build filters based on sorting
|
||||
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
|
||||
base_query = base_query.filter(filter_condition)
|
||||
|
||||
base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
|
||||
|
||||
conversations = base_query.limit(limit).all()
|
||||
|
||||
has_more = False
|
||||
if len(conversations) == limit:
|
||||
current_page_first_conversation = conversations[-1]
|
||||
rest_count = base_query.filter(
|
||||
Conversation.created_at < current_page_first_conversation.created_at,
|
||||
Conversation.id != current_page_first_conversation.id
|
||||
).count()
|
||||
current_page_last_conversation = conversations[-1]
|
||||
rest_filter_condition = cls._build_filter_condition(
|
||||
sort_field, sort_direction, current_page_last_conversation, is_next_page=True
|
||||
)
|
||||
rest_count = base_query.filter(rest_filter_condition).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=conversations,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more)
|
||||
|
||||
@classmethod
|
||||
def rename(cls, app_model: App, conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool):
|
||||
def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
|
||||
if sort_by.startswith("-"):
|
||||
return sort_by[1:], desc
|
||||
return sort_by, asc
|
||||
|
||||
@classmethod
|
||||
def _build_filter_condition(
|
||||
cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False
|
||||
):
|
||||
field_value = getattr(reference_conversation, sort_field)
|
||||
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
|
||||
return getattr(Conversation, sort_field) < field_value
|
||||
else:
|
||||
return getattr(Conversation, sort_field) > field_value
|
||||
|
||||
@classmethod
|
||||
def rename(
|
||||
cls,
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
name: str,
|
||||
auto_generate: bool,
|
||||
):
|
||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||
|
||||
if auto_generate:
|
||||
return cls.auto_generate_name(app_model, conversation)
|
||||
else:
|
||||
conversation.name = name
|
||||
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return conversation
|
||||
@ -85,11 +112,12 @@ class ConversationService:
|
||||
@classmethod
|
||||
def auto_generate_name(cls, app_model: App, conversation: Conversation):
|
||||
# get conversation first message
|
||||
message = db.session.query(Message) \
|
||||
.filter(
|
||||
Message.app_id == app_model.id,
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at.asc()).first()
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
@ -109,15 +137,18 @@ class ConversationService:
|
||||
|
||||
@classmethod
|
||||
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
conversation = db.session.query(Conversation) \
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
Conversation.is_deleted == False
|
||||
).first()
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
Conversation.is_deleted == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -4,15 +4,12 @@ import requests
|
||||
|
||||
|
||||
class EnterpriseRequest:
|
||||
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
|
||||
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
|
||||
|
||||
@classmethod
|
||||
def send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Enterprise-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
|
||||
@ -2,7 +2,10 @@ from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
return EnterpriseRequest.send_request('GET', '/info')
|
||||
return EnterpriseRequest.send_request("GET", "/info")
|
||||
|
||||
@classmethod
|
||||
def get_app_web_sso_enabled(cls, app_code):
|
||||
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
|
||||
|
||||
@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
ACTIVE = 'active'
|
||||
NO_CONFIGURE = 'no-configure'
|
||||
|
||||
ACTIVE = "active"
|
||||
NO_CONFIGURE = "no-configure"
|
||||
|
||||
|
||||
class CustomConfigurationResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration response.
|
||||
"""
|
||||
|
||||
status: CustomConfigurationStatus
|
||||
|
||||
|
||||
@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration response.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
@ -46,6 +49,7 @@ class ProviderResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
@ -67,18 +71,15 @@ class ProviderResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
)
|
||||
|
||||
if self.icon_large is not None:
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
|
||||
|
||||
@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider with models response.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
)
|
||||
|
||||
if self.icon_large is not None:
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
|
||||
|
||||
@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
)
|
||||
|
||||
if self.icon_large is not None:
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
|
||||
|
||||
@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel):
|
||||
"""
|
||||
Default model entity.
|
||||
"""
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
provider: SimpleProviderEntityResponse
|
||||
@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, model: ModelWithProviderEntity) -> None:
|
||||
|
||||
@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError):
|
||||
|
||||
class RateLimitExceededError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
class BaseServiceError(Exception):
|
||||
def __init__(self, description: str = None):
|
||||
self.description = description
|
||||
self.description = description
|
||||
|
||||
19
api/services/errors/llm.py
Normal file
19
api/services/errors/llm.py
Normal file
@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InvokeError(Exception):
|
||||
"""Base class for all LLM exceptions."""
|
||||
|
||||
description: Optional[str] = None
|
||||
|
||||
def __init__(self, description: Optional[str] = None) -> None:
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
|
||||
class InvokeRateLimitError(InvokeError):
|
||||
"""Raised when the Invoke returns rate limit error."""
|
||||
|
||||
description = "Rate Limit Error"
|
||||
@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class SubscriptionModel(BaseModel):
|
||||
plan: str = 'sandbox'
|
||||
interval: str = ''
|
||||
plan: str = "sandbox"
|
||||
interval: str = ""
|
||||
|
||||
|
||||
class BillingModel(BaseModel):
|
||||
@ -27,7 +27,7 @@ class FeatureModel(BaseModel):
|
||||
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
|
||||
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
|
||||
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
|
||||
docs_processing: str = 'standard'
|
||||
docs_processing: str = "standard"
|
||||
can_replace_logo: bool = False
|
||||
model_load_balancing_enabled: bool = False
|
||||
dataset_operator_enabled: bool = False
|
||||
@ -38,13 +38,13 @@ class FeatureModel(BaseModel):
|
||||
|
||||
class SystemFeatureModel(BaseModel):
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ''
|
||||
sso_enforced_for_signin_protocol: str = ""
|
||||
sso_enforced_for_web: bool = False
|
||||
sso_enforced_for_web_protocol: str = ''
|
||||
sso_enforced_for_web_protocol: str = ""
|
||||
enable_web_sso_switch_component: bool = False
|
||||
|
||||
|
||||
class FeatureService:
|
||||
|
||||
@classmethod
|
||||
def get_features(cls, tenant_id: str) -> FeatureModel:
|
||||
features = FeatureModel()
|
||||
@ -61,6 +61,7 @@ class FeatureService:
|
||||
system_features = SystemFeatureModel()
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
system_features.enable_web_sso_switch_component = True
|
||||
cls._fulfill_params_from_enterprise(system_features)
|
||||
|
||||
return system_features
|
||||
@ -75,44 +76,44 @@ class FeatureService:
|
||||
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
|
||||
features.billing.enabled = billing_info['enabled']
|
||||
features.billing.subscription.plan = billing_info['subscription']['plan']
|
||||
features.billing.subscription.interval = billing_info['subscription']['interval']
|
||||
features.billing.enabled = billing_info["enabled"]
|
||||
features.billing.subscription.plan = billing_info["subscription"]["plan"]
|
||||
features.billing.subscription.interval = billing_info["subscription"]["interval"]
|
||||
|
||||
if 'members' in billing_info:
|
||||
features.members.size = billing_info['members']['size']
|
||||
features.members.limit = billing_info['members']['limit']
|
||||
if "members" in billing_info:
|
||||
features.members.size = billing_info["members"]["size"]
|
||||
features.members.limit = billing_info["members"]["limit"]
|
||||
|
||||
if 'apps' in billing_info:
|
||||
features.apps.size = billing_info['apps']['size']
|
||||
features.apps.limit = billing_info['apps']['limit']
|
||||
if "apps" in billing_info:
|
||||
features.apps.size = billing_info["apps"]["size"]
|
||||
features.apps.limit = billing_info["apps"]["limit"]
|
||||
|
||||
if 'vector_space' in billing_info:
|
||||
features.vector_space.size = billing_info['vector_space']['size']
|
||||
features.vector_space.limit = billing_info['vector_space']['limit']
|
||||
if "vector_space" in billing_info:
|
||||
features.vector_space.size = billing_info["vector_space"]["size"]
|
||||
features.vector_space.limit = billing_info["vector_space"]["limit"]
|
||||
|
||||
if 'documents_upload_quota' in billing_info:
|
||||
features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
|
||||
features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
|
||||
if "documents_upload_quota" in billing_info:
|
||||
features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"]
|
||||
features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"]
|
||||
|
||||
if 'annotation_quota_limit' in billing_info:
|
||||
features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
|
||||
features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
|
||||
if "annotation_quota_limit" in billing_info:
|
||||
features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"]
|
||||
features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"]
|
||||
|
||||
if 'docs_processing' in billing_info:
|
||||
features.docs_processing = billing_info['docs_processing']
|
||||
if "docs_processing" in billing_info:
|
||||
features.docs_processing = billing_info["docs_processing"]
|
||||
|
||||
if 'can_replace_logo' in billing_info:
|
||||
features.can_replace_logo = billing_info['can_replace_logo']
|
||||
if "can_replace_logo" in billing_info:
|
||||
features.can_replace_logo = billing_info["can_replace_logo"]
|
||||
|
||||
if 'model_load_balancing_enabled' in billing_info:
|
||||
features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled']
|
||||
if "model_load_balancing_enabled" in billing_info:
|
||||
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_enterprise(cls, features):
|
||||
enterprise_info = EnterpriseService.get_info()
|
||||
|
||||
features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
|
||||
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
|
||||
features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web']
|
||||
features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol']
|
||||
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
|
||||
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
|
||||
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
|
||||
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
|
||||
|
||||
@ -17,27 +17,45 @@ from models.account import Account
|
||||
from models.model import EndUser, UploadFile
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
|
||||
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv']
|
||||
UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls',
|
||||
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub']
|
||||
ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||
UNSTRUCTURED_ALLOWED_EXTENSIONS = [
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
"pdf",
|
||||
"html",
|
||||
"htm",
|
||||
"xlsx",
|
||||
"xls",
|
||||
"docx",
|
||||
"csv",
|
||||
"eml",
|
||||
"msg",
|
||||
"pptx",
|
||||
"ppt",
|
||||
"xml",
|
||||
"epub",
|
||||
]
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class FileService:
|
||||
|
||||
@staticmethod
|
||||
def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
|
||||
filename = file.filename
|
||||
extension = file.filename.split('.')[-1]
|
||||
extension = file.filename.split(".")[-1]
|
||||
if len(filename) > 200:
|
||||
filename = filename.split('.')[0][:200] + '.' + extension
|
||||
filename = filename.split(".")[0][:200] + "." + extension
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
|
||||
allowed_extensions = (
|
||||
UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
|
||||
if etl_type == "Unstructured"
|
||||
else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
|
||||
)
|
||||
if extension.lower() not in allowed_extensions:
|
||||
raise UnsupportedFileTypeError()
|
||||
elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
|
||||
@ -55,7 +73,7 @@ class FileService:
|
||||
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
|
||||
if file_size > file_size_limit:
|
||||
message = f'File size exceeded. {file_size} > {file_size_limit}'
|
||||
message = f"File size exceeded. {file_size} > {file_size_limit}"
|
||||
raise FileTooLargeError(message)
|
||||
|
||||
# user uuid as file name
|
||||
@ -67,7 +85,7 @@ class FileService:
|
||||
# end_user
|
||||
current_tenant_id = user.tenant_id
|
||||
|
||||
file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
|
||||
file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, file_content)
|
||||
@ -81,11 +99,11 @@ class FileService:
|
||||
size=file_size,
|
||||
extension=extension,
|
||||
mime_type=file.mimetype,
|
||||
created_by_role=('account' if isinstance(user, Account) else 'end_user'),
|
||||
created_by_role=("account" if isinstance(user, Account) else "end_user"),
|
||||
created_by=user.id,
|
||||
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_content).hexdigest()
|
||||
hash=hashlib.sha3_256(file_content).hexdigest(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
@ -99,25 +117,25 @@ class FileService:
|
||||
text_name = text_name[:200]
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
|
||||
file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, text.encode('utf-8'))
|
||||
storage.save(file_key, text.encode("utf-8"))
|
||||
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=text_name + '.txt',
|
||||
name=text_name,
|
||||
size=len(text),
|
||||
extension='txt',
|
||||
mime_type='text/plain',
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by=current_user.id,
|
||||
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
used=True,
|
||||
used_by=current_user.id,
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
@ -127,9 +145,7 @@ class FileService:
|
||||
|
||||
@staticmethod
|
||||
def get_file_preview(file_id: str) -> str:
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
@ -137,12 +153,12 @@ class FileService:
|
||||
# extract text from file
|
||||
extension = upload_file.extension
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
|
||||
if extension.lower() not in allowed_extensions:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
|
||||
|
||||
return text
|
||||
|
||||
@ -152,9 +168,7 @@ class FileService:
|
||||
if not result:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -170,9 +184,7 @@ class FileService:
|
||||
|
||||
@staticmethod
|
||||
def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
@ -9,14 +9,11 @@ from models.account import Account
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@ -27,9 +24,9 @@ class HitTestingService:
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
"tsne_position": {'x': 0, 'y': 0},
|
||||
"tsne_position": {"x": 0, "y": 0},
|
||||
},
|
||||
"records": []
|
||||
"records": [],
|
||||
}
|
||||
|
||||
start = time.perf_counter()
|
||||
@ -38,27 +35,28 @@ class HitTestingService:
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
top_k=retrieval_model.get('top_k', 2),
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode', None),
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrival_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
top_k=retrieval_model.get("top_k", 2),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else None,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode")
|
||||
if retrieval_model.get("reranking_mode")
|
||||
else "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=query,
|
||||
source='hit_testing',
|
||||
created_by_role='account',
|
||||
created_by=account.id
|
||||
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
@ -71,14 +69,18 @@ class HitTestingService:
|
||||
i = 0
|
||||
records = []
|
||||
for document in documents:
|
||||
index_node_id = document.metadata['doc_id']
|
||||
index_node_id = document.metadata["doc_id"]
|
||||
|
||||
segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.index_node_id == index_node_id
|
||||
).first()
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not segment:
|
||||
i += 1
|
||||
@ -86,7 +88,7 @@ class HitTestingService:
|
||||
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get('score', None),
|
||||
"score": document.metadata.get("score", None),
|
||||
}
|
||||
|
||||
records.append(record)
|
||||
@ -97,15 +99,15 @@ class HitTestingService:
|
||||
"query": {
|
||||
"content": query,
|
||||
},
|
||||
"records": records
|
||||
"records": records,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def hit_testing_args_check(cls, args):
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
|
||||
if not query or len(query) > 250:
|
||||
raise ValueError('Query is required and cannot exceed 250 characters')
|
||||
raise ValueError("Query is required and cannot exceed 250 characters")
|
||||
|
||||
@staticmethod
|
||||
def escape_query_for_search(query: str) -> str:
|
||||
|
||||
@ -7,7 +7,8 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
@ -26,8 +27,14 @@ from services.workflow_service import WorkflowService
|
||||
|
||||
class MessageService:
|
||||
@classmethod
|
||||
def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination:
|
||||
def pagination_by_first_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
conversation_id: str,
|
||||
first_id: Optional[str],
|
||||
limit: int,
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
@ -35,52 +42,69 @@ class MessageService:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
conversation_id=conversation_id
|
||||
app_model=app_model, user=user, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
if first_id:
|
||||
first_message = db.session.query(Message) \
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == first_id).first()
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == first_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise FirstMessageNotExistsError()
|
||||
|
||||
history_messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id
|
||||
) \
|
||||
.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
|
||||
.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id
|
||||
).count()
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=history_messages,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
|
||||
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int, conversation_id: Optional[str] = None,
|
||||
include_ids: Optional[list] = None) -> InfiniteScrollPagination:
|
||||
def pagination_by_last_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
limit: int,
|
||||
conversation_id: Optional[str] = None,
|
||||
include_ids: Optional[list] = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
@ -88,9 +112,7 @@ class MessageService:
|
||||
|
||||
if conversation_id is not None:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
conversation_id=conversation_id
|
||||
app_model=app_model, user=user, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
base_query = base_query.filter(Message.conversation_id == conversation.id)
|
||||
@ -104,10 +126,12 @@ class MessageService:
|
||||
if not last_message:
|
||||
raise LastMessageNotExistsError()
|
||||
|
||||
history_messages = base_query.filter(
|
||||
Message.created_at < last_message.created_at,
|
||||
Message.id != last_message.id
|
||||
).order_by(Message.created_at.desc()).limit(limit).all()
|
||||
history_messages = (
|
||||
base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
|
||||
@ -115,30 +139,22 @@ class MessageService:
|
||||
if len(history_messages) == limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = base_query.filter(
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id
|
||||
Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=history_messages,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
|
||||
|
||||
@classmethod
|
||||
def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]],
|
||||
rating: Optional[str]) -> MessageFeedback:
|
||||
def create_feedback(
|
||||
cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str]
|
||||
) -> MessageFeedback:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
raise ValueError("user cannot be None")
|
||||
|
||||
message = cls.get_message(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
message_id=message_id
|
||||
)
|
||||
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
|
||||
|
||||
feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback
|
||||
|
||||
@ -147,14 +163,14 @@ class MessageService:
|
||||
elif rating and feedback:
|
||||
feedback.rating = rating
|
||||
elif not rating and not feedback:
|
||||
raise ValueError('rating cannot be None when feedback not exists')
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating,
|
||||
from_source=('user' if isinstance(user, EndUser) else 'admin'),
|
||||
from_source=("user" if isinstance(user, EndUser) else "admin"),
|
||||
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
|
||||
from_account_id=(user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
@ -166,13 +182,17 @@ class MessageService:
|
||||
|
||||
@classmethod
|
||||
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
).first()
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
@ -180,27 +200,22 @@ class MessageService:
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
message_id: str, invoke_from: InvokeFrom) -> list[Message]:
|
||||
def get_suggested_questions_after_answer(
|
||||
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
|
||||
) -> list[Message]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
raise ValueError("user cannot be None")
|
||||
|
||||
message = cls.get_message(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
message_id=message_id
|
||||
)
|
||||
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
|
||||
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
conversation_id=message.conversation_id,
|
||||
user=user
|
||||
app_model=app_model, conversation_id=message.conversation_id, user=user
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
if conversation.status != 'normal':
|
||||
if conversation.status != "normal":
|
||||
raise ConversationCompletedError()
|
||||
|
||||
model_manager = ModelManager()
|
||||
@ -215,24 +230,23 @@ class MessageService:
|
||||
if workflow is None:
|
||||
return []
|
||||
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
if not app_config.additional_features.suggested_questions_after_answer:
|
||||
raise SuggestedQuestionsAfterAnswerDisabledError()
|
||||
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
tenant_id=app_model.tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
else:
|
||||
if not conversation.override_model_configs:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id,
|
||||
AppModelConfig.app_id == app_model.id
|
||||
).first()
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
app_model_config = AppModelConfig(
|
||||
@ -248,16 +262,13 @@ class MessageService:
|
||||
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider=app_model_config.model_dict['provider'],
|
||||
provider=app_model_config.model_dict["provider"],
|
||||
model_type=ModelType.LLM,
|
||||
model=app_model_config.model_dict['name']
|
||||
model=app_model_config.model_dict["name"],
|
||||
)
|
||||
|
||||
# get memory of conversation (read-only)
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
histories = memory.get_history_prompt_text(
|
||||
max_token_limit=3000,
|
||||
@ -266,18 +277,14 @@ class MessageService:
|
||||
|
||||
with measure_time() as timer:
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer(
|
||||
tenant_id=app_model.tenant_id,
|
||||
histories=histories
|
||||
tenant_id=app_model.tenant_id, histories=histories
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
message_id=message_id,
|
||||
suggested_question=questions,
|
||||
timer=timer
|
||||
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_configuration import ProviderConfiguration
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
@ -22,7 +23,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelLoadBalancingService:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
@ -45,10 +45,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model load balancing
|
||||
provider_configuration.enable_model_load_balancing(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
"""
|
||||
@ -69,13 +66,11 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# disable model load balancing
|
||||
provider_configuration.disable_model_load_balancing(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
|
||||
-> tuple[bool, list[dict]]:
|
||||
def get_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str
|
||||
) -> tuple[bool, list[dict]]:
|
||||
"""
|
||||
Get load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@ -106,20 +101,24 @@ class ModelLoadBalancingService:
|
||||
is_load_balancing_enabled = True
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
|
||||
load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model
|
||||
).order_by(LoadBalancingModelConfig.created_at).all()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
if provider_configuration.custom_configuration.provider:
|
||||
# check if the inherit configuration exists,
|
||||
# inherit is represented for the provider or model custom credentials
|
||||
inherit_config_exists = False
|
||||
for load_balancing_config in load_balancing_configs:
|
||||
if load_balancing_config.name == '__inherit__':
|
||||
if load_balancing_config.name == "__inherit__":
|
||||
inherit_config_exists = True
|
||||
break
|
||||
|
||||
@ -132,7 +131,7 @@ class ModelLoadBalancingService:
|
||||
else:
|
||||
# move the inherit configuration to the first
|
||||
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
|
||||
if load_balancing_config.name == '__inherit__':
|
||||
if load_balancing_config.name == "__inherit__":
|
||||
inherit_config = load_balancing_configs.pop(i)
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
|
||||
@ -150,7 +149,7 @@ class ModelLoadBalancingService:
|
||||
provider=provider,
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
config_id=load_balancing_config.id
|
||||
config_id=load_balancing_config.id,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -171,32 +170,32 @@ class ModelLoadBalancingService:
|
||||
if variable in credentials:
|
||||
try:
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
credentials.get(variable),
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa
|
||||
credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Obfuscate credentials
|
||||
credentials = provider_configuration.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
)
|
||||
|
||||
datas.append({
|
||||
'id': load_balancing_config.id,
|
||||
'name': load_balancing_config.name,
|
||||
'credentials': credentials,
|
||||
'enabled': load_balancing_config.enabled,
|
||||
'in_cooldown': in_cooldown,
|
||||
'ttl': ttl
|
||||
})
|
||||
datas.append(
|
||||
{
|
||||
"id": load_balancing_config.id,
|
||||
"name": load_balancing_config.name,
|
||||
"credentials": credentials,
|
||||
"enabled": load_balancing_config.enabled,
|
||||
"in_cooldown": in_cooldown,
|
||||
"ttl": ttl,
|
||||
}
|
||||
)
|
||||
|
||||
return is_load_balancing_enabled, datas
|
||||
|
||||
def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
|
||||
-> Optional[dict]:
|
||||
def get_load_balancing_config(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get load balancing configuration.
|
||||
:param tenant_id: workspace id
|
||||
@ -218,14 +217,17 @@ class ModelLoadBalancingService:
|
||||
model_type = ModelType.value_of(model_type)
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id
|
||||
).first()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
return None
|
||||
@ -243,19 +245,19 @@ class ModelLoadBalancingService:
|
||||
|
||||
# Obfuscate credentials
|
||||
credentials = provider_configuration.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
)
|
||||
|
||||
return {
|
||||
'id': load_balancing_model_config.id,
|
||||
'name': load_balancing_model_config.name,
|
||||
'credentials': credentials,
|
||||
'enabled': load_balancing_model_config.enabled
|
||||
"id": load_balancing_model_config.id,
|
||||
"name": load_balancing_model_config.name,
|
||||
"credentials": credentials,
|
||||
"enabled": load_balancing_model_config.enabled,
|
||||
}
|
||||
|
||||
def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
|
||||
-> LoadBalancingModelConfig:
|
||||
def _init_inherit_config(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: ModelType
|
||||
) -> LoadBalancingModelConfig:
|
||||
"""
|
||||
Initialize the inherit configuration.
|
||||
:param tenant_id: workspace id
|
||||
@ -270,18 +272,16 @@ class ModelLoadBalancingService:
|
||||
provider_name=provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name='__inherit__'
|
||||
name="__inherit__",
|
||||
)
|
||||
db.session.add(inherit_config)
|
||||
db.session.commit()
|
||||
|
||||
return inherit_config
|
||||
|
||||
def update_load_balancing_configs(self, tenant_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
configs: list[dict]) -> None:
|
||||
def update_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
|
||||
) -> None:
|
||||
"""
|
||||
Update load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@ -303,15 +303,18 @@ class ModelLoadBalancingService:
|
||||
model_type = ModelType.value_of(model_type)
|
||||
|
||||
if not isinstance(configs, list):
|
||||
raise ValueError('Invalid load balancing configs')
|
||||
raise ValueError("Invalid load balancing configs")
|
||||
|
||||
current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
|
||||
current_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model
|
||||
).all()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# id as key, config as value
|
||||
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
|
||||
@ -319,25 +322,25 @@ class ModelLoadBalancingService:
|
||||
|
||||
for config in configs:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Invalid load balancing config')
|
||||
raise ValueError("Invalid load balancing config")
|
||||
|
||||
config_id = config.get('id')
|
||||
name = config.get('name')
|
||||
credentials = config.get('credentials')
|
||||
enabled = config.get('enabled')
|
||||
config_id = config.get("id")
|
||||
name = config.get("name")
|
||||
credentials = config.get("credentials")
|
||||
enabled = config.get("enabled")
|
||||
|
||||
if not name:
|
||||
raise ValueError('Invalid load balancing config name')
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
if enabled is None:
|
||||
raise ValueError('Invalid load balancing config enabled')
|
||||
raise ValueError("Invalid load balancing config enabled")
|
||||
|
||||
# is config exists
|
||||
if config_id:
|
||||
config_id = str(config_id)
|
||||
|
||||
if config_id not in current_load_balancing_configs_dict:
|
||||
raise ValueError('Invalid load balancing config id: {}'.format(config_id))
|
||||
raise ValueError("Invalid load balancing config id: {}".format(config_id))
|
||||
|
||||
updated_config_ids.add(config_id)
|
||||
|
||||
@ -346,11 +349,11 @@ class ModelLoadBalancingService:
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
|
||||
raise ValueError('Load balancing config name {} already exists'.format(name))
|
||||
raise ValueError("Load balancing config name {} already exists".format(name))
|
||||
|
||||
if credentials:
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError('Invalid load balancing config credentials')
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
@ -360,7 +363,7 @@ class ModelLoadBalancingService:
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_config,
|
||||
validate=False
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# update load balancing config
|
||||
@ -374,19 +377,19 @@ class ModelLoadBalancingService:
|
||||
self._clear_credentials_cache(tenant_id, config_id)
|
||||
else:
|
||||
# create load balancing config
|
||||
if name == '__inherit__':
|
||||
raise ValueError('Invalid load balancing config name')
|
||||
if name == "__inherit__":
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.name == name:
|
||||
raise ValueError('Load balancing config name {} already exists'.format(name))
|
||||
raise ValueError("Load balancing config name {} already exists".format(name))
|
||||
|
||||
if not credentials:
|
||||
raise ValueError('Invalid load balancing config credentials')
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError('Invalid load balancing config credentials')
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
@ -395,7 +398,7 @@ class ModelLoadBalancingService:
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# create load balancing config
|
||||
@ -405,7 +408,7 @@ class ModelLoadBalancingService:
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials)
|
||||
encrypted_config=json.dumps(credentials),
|
||||
)
|
||||
|
||||
db.session.add(load_balancing_model_config)
|
||||
@ -419,12 +422,15 @@ class ModelLoadBalancingService:
|
||||
|
||||
self._clear_credentials_cache(tenant_id, config_id)
|
||||
|
||||
def validate_load_balancing_credentials(self, tenant_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
credentials: dict,
|
||||
config_id: Optional[str] = None) -> None:
|
||||
def validate_load_balancing_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
credentials: dict,
|
||||
config_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Validate load balancing credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -449,14 +455,17 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = None
|
||||
if config_id:
|
||||
# Get load balancing config
|
||||
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id
|
||||
).first()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
raise ValueError(f"Load balancing config {config_id} does not exist.")
|
||||
@ -468,16 +477,19 @@ class ModelLoadBalancingService:
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_model_config
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
)
|
||||
|
||||
def _custom_credentials_validate(self, tenant_id: str,
|
||||
provider_configuration: ProviderConfiguration,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
|
||||
validate: bool = True) -> dict:
|
||||
def _custom_credentials_validate(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider_configuration: ProviderConfiguration,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
|
||||
validate: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -511,7 +523,7 @@ class ModelLoadBalancingService:
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
|
||||
|
||||
if validate:
|
||||
@ -520,12 +532,11 @@ class ModelLoadBalancingService:
|
||||
provider=provider_configuration.provider.provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials
|
||||
credentials=credentials,
|
||||
)
|
||||
else:
|
||||
credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=provider_configuration.provider.provider,
|
||||
credentials=credentials
|
||||
provider=provider_configuration.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in credentials.items():
|
||||
@ -534,8 +545,9 @@ class ModelLoadBalancingService:
|
||||
|
||||
return credentials
|
||||
|
||||
def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
|
||||
-> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
def _get_credential_schema(
|
||||
self, provider_configuration: ProviderConfiguration
|
||||
) -> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
"""
|
||||
Get form schemas.
|
||||
:param provider_configuration: provider configuration
|
||||
@ -557,9 +569,7 @@ class ModelLoadBalancingService:
|
||||
:return:
|
||||
"""
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=config_id,
|
||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
|
||||
tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
@ -30,6 +30,7 @@ class ModelProviderService:
|
||||
"""
|
||||
Model Provider Service
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
@ -72,8 +73,8 @@ class ModelProviderService:
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
enabled=provider_configuration.system_configuration.enabled,
|
||||
current_quota_type=provider_configuration.system_configuration.current_quota_type,
|
||||
quota_configurations=provider_configuration.system_configuration.quota_configurations
|
||||
)
|
||||
quota_configurations=provider_configuration.system_configuration.quota_configurations,
|
||||
),
|
||||
)
|
||||
|
||||
provider_responses.append(provider_response)
|
||||
@ -94,9 +95,9 @@ class ModelProviderService:
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider available models
|
||||
return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(
|
||||
provider=provider
|
||||
)]
|
||||
return [
|
||||
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
|
||||
"""
|
||||
@ -194,13 +195,12 @@ class ModelProviderService:
|
||||
|
||||
# Get model custom credentials from ProviderModel if exists
|
||||
return provider_configuration.get_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
obfuscated=True
|
||||
model_type=ModelType.value_of(model_type), model=model, obfuscated=True
|
||||
)
|
||||
|
||||
def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str,
|
||||
credentials: dict) -> None:
|
||||
def model_credentials_validate(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
@ -221,13 +221,12 @@ class ModelProviderService:
|
||||
|
||||
# Validate model credentials
|
||||
provider_configuration.custom_model_credentials_validate(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str,
|
||||
credentials: dict) -> None:
|
||||
def save_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
"""
|
||||
save model credentials.
|
||||
|
||||
@ -248,9 +247,7 @@ class ModelProviderService:
|
||||
|
||||
# Add or update custom model credentials
|
||||
provider_configuration.add_or_update_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
@ -272,10 +269,7 @@ class ModelProviderService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom model credentials
|
||||
provider_configuration.delete_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model
|
||||
)
|
||||
provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model)
|
||||
|
||||
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
|
||||
"""
|
||||
@ -289,9 +283,7 @@ class ModelProviderService:
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider available models
|
||||
models = provider_configurations.get_models(
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
|
||||
|
||||
# Group models by provider
|
||||
provider_models = {}
|
||||
@ -322,16 +314,19 @@ class ModelProviderService:
|
||||
icon_small=first_model.provider.icon_small,
|
||||
icon_large=first_model.provider.icon_large,
|
||||
status=CustomConfigurationStatus.ACTIVE,
|
||||
models=[ProviderModelWithStatusEntity(
|
||||
model=model.model,
|
||||
label=model.label,
|
||||
model_type=model.model_type,
|
||||
features=model.features,
|
||||
fetch_from=model.fetch_from,
|
||||
model_properties=model.model_properties,
|
||||
status=model.status,
|
||||
load_balancing_enabled=model.load_balancing_enabled
|
||||
) for model in models]
|
||||
models=[
|
||||
ProviderModelWithStatusEntity(
|
||||
model=model.model,
|
||||
label=model.label,
|
||||
model_type=model.model_type,
|
||||
features=model.features,
|
||||
fetch_from=model.fetch_from,
|
||||
model_properties=model.model_properties,
|
||||
status=model.status,
|
||||
load_balancing_enabled=model.load_balancing_enabled,
|
||||
)
|
||||
for model in models
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@ -360,19 +355,13 @@ class ModelProviderService:
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# fetch credentials
|
||||
credentials = provider_configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model
|
||||
)
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
|
||||
|
||||
if not credentials:
|
||||
return []
|
||||
|
||||
# Call get_parameter_rules method of model instance to get model parameter rules
|
||||
return model_type_instance.get_parameter_rules(
|
||||
model=model,
|
||||
credentials=credentials
|
||||
)
|
||||
return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
|
||||
"""
|
||||
@ -383,22 +372,26 @@ class ModelProviderService:
|
||||
:return:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
result = self.provider_manager.get_default_model(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type_enum
|
||||
)
|
||||
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
|
||||
try:
|
||||
return (
|
||||
DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types,
|
||||
),
|
||||
)
|
||||
if result
|
||||
else None
|
||||
)
|
||||
) if result else None
|
||||
except Exception as e:
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
"""
|
||||
@ -412,13 +405,12 @@ class ModelProviderService:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
self.provider_manager.update_default_model_record(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type_enum,
|
||||
provider=provider,
|
||||
model=model
|
||||
tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
|
||||
)
|
||||
|
||||
def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]:
|
||||
def get_model_provider_icon(
|
||||
self, provider: str, icon_type: str, lang: str
|
||||
) -> tuple[Optional[bytes], Optional[str]]:
|
||||
"""
|
||||
get model provider icon.
|
||||
|
||||
@ -430,11 +422,11 @@ class ModelProviderService:
|
||||
provider_instance = model_provider_factory.get_provider_instance(provider)
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
|
||||
if icon_type.lower() == 'icon_small':
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
|
||||
if lang.lower() == 'zh_hans':
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_small.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_small.en_US
|
||||
@ -442,13 +434,15 @@ class ModelProviderService:
|
||||
if not provider_schema.icon_large:
|
||||
raise ValueError(f"Provider {provider} does not have large icon.")
|
||||
|
||||
if lang.lower() == 'zh_hans':
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
|
||||
root_path = current_app.root_path
|
||||
provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/')))
|
||||
provider_instance_path = os.path.dirname(
|
||||
os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/"))
|
||||
)
|
||||
file_path = os.path.join(provider_instance_path, "_assets")
|
||||
file_path = os.path.join(file_path, file_name)
|
||||
|
||||
@ -456,10 +450,10 @@ class ModelProviderService:
|
||||
return None, None
|
||||
|
||||
mimetype, _ = mimetypes.guess_type(file_path)
|
||||
mimetype = mimetype or 'application/octet-stream'
|
||||
mimetype = mimetype or "application/octet-stream"
|
||||
|
||||
# read binary from file
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
byte_data = f.read()
|
||||
return byte_data, mimetype
|
||||
|
||||
@ -505,10 +499,7 @@ class ModelProviderService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration.enable_model(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
"""
|
||||
@ -529,78 +520,49 @@ class ModelProviderService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration.disable_model(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider: str):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_url = api_base_url + '/api/v1/providers/apply'
|
||||
api_url = api_base_url + "/api/v1/providers/apply"
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {api_key}"
|
||||
}
|
||||
response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider})
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider})
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
if response.json()["code"] != 'success':
|
||||
raise ValueError(
|
||||
f"error: {response.json()['message']}"
|
||||
)
|
||||
if response.json()["code"] != "success":
|
||||
raise ValueError(f"error: {response.json()['message']}")
|
||||
|
||||
rst = response.json()
|
||||
|
||||
if rst['type'] == 'redirect':
|
||||
return {
|
||||
'type': rst['type'],
|
||||
'redirect_url': rst['redirect_url']
|
||||
}
|
||||
if rst["type"] == "redirect":
|
||||
return {"type": rst["type"], "redirect_url": rst["redirect_url"]}
|
||||
else:
|
||||
return {
|
||||
'type': rst['type'],
|
||||
'result': 'success'
|
||||
}
|
||||
return {"type": rst["type"], "result": "success"}
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_url = api_base_url + '/api/v1/providers/qualification-verify'
|
||||
api_url = api_base_url + "/api/v1/providers/qualification-verify"
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {api_key}"
|
||||
}
|
||||
json_data = {'workspace_id': tenant_id, 'provider_name': provider}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
json_data = {"workspace_id": tenant_id, "provider_name": provider}
|
||||
if token:
|
||||
json_data['token'] = token
|
||||
response = requests.post(api_url, headers=headers,
|
||||
json=json_data)
|
||||
json_data["token"] = token
|
||||
response = requests.post(api_url, headers=headers, json=json_data)
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
rst = response.json()
|
||||
if rst["code"] != 'success':
|
||||
raise ValueError(
|
||||
f"error: {rst['message']}"
|
||||
)
|
||||
if rst["code"] != "success":
|
||||
raise ValueError(f"error: {rst['message']}")
|
||||
|
||||
data = rst['data']
|
||||
if data['qualified'] is True:
|
||||
return {
|
||||
'result': 'success',
|
||||
'provider_name': provider,
|
||||
'flag': True
|
||||
}
|
||||
data = rst["data"]
|
||||
if data["qualified"] is True:
|
||||
return {"result": "success", "provider_name": provider, "flag": True}
|
||||
else:
|
||||
return {
|
||||
'result': 'success',
|
||||
'provider_name': provider,
|
||||
'flag': False,
|
||||
'reason': data['reason']
|
||||
}
|
||||
return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]}
|
||||
|
||||
@ -4,17 +4,18 @@ from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
class ModerationService:
|
||||
|
||||
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
|
||||
app_model_config: AppModelConfig = None
|
||||
|
||||
app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
|
||||
if not app_model_config:
|
||||
raise ValueError("app model config not found")
|
||||
|
||||
name = app_model_config.sensitive_word_avoidance_dict['type']
|
||||
config = app_model_config.sensitive_word_avoidance_dict['config']
|
||||
name = app_model_config.sensitive_word_avoidance_dict["type"]
|
||||
config = app_model_config.sensitive_word_avoidance_dict["config"]
|
||||
|
||||
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
|
||||
return moderation.moderation_for_outputs(text)
|
||||
|
||||
@ -4,15 +4,12 @@ import requests
|
||||
|
||||
|
||||
class OperationService:
|
||||
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
|
||||
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
|
||||
@classmethod
|
||||
def _send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Billing-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
@ -22,11 +19,11 @@ class OperationService:
|
||||
@classmethod
|
||||
def record_utm(cls, tenant_id: str, utm_info: dict):
|
||||
params = {
|
||||
'tenant_id': tenant_id,
|
||||
'utm_source': utm_info.get('utm_source', ''),
|
||||
'utm_medium': utm_info.get('utm_medium', ''),
|
||||
'utm_campaign': utm_info.get('utm_campaign', ''),
|
||||
'utm_content': utm_info.get('utm_content', ''),
|
||||
'utm_term': utm_info.get('utm_term', '')
|
||||
"tenant_id": tenant_id,
|
||||
"utm_source": utm_info.get("utm_source", ""),
|
||||
"utm_medium": utm_info.get("utm_medium", ""),
|
||||
"utm_campaign": utm_info.get("utm_campaign", ""),
|
||||
"utm_content": utm_info.get("utm_content", ""),
|
||||
"utm_term": utm_info.get("utm_term", ""),
|
||||
}
|
||||
return cls._send_request('POST', '/tenant_utms', params=params)
|
||||
return cls._send_request("POST", "/tenant_utms", params=params)
|
||||
|
||||
@ -12,20 +12,29 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
trace_config_data: TraceAppConfig = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not trace_config_data:
|
||||
return None
|
||||
|
||||
# decrypt_token and obfuscated_token
|
||||
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
|
||||
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config)
|
||||
decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
|
||||
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
|
||||
tenant_id, tracing_provider, trace_config_data.tracing_config
|
||||
)
|
||||
new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
|
||||
|
||||
trace_config_data.tracing_config = decrypt_tracing_config
|
||||
if tracing_provider == "langfuse" and (
|
||||
"project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
|
||||
):
|
||||
project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_key": project_key})
|
||||
|
||||
trace_config_data.tracing_config = new_decrypt_tracing_config
|
||||
return trace_config_data.to_dict()
|
||||
|
||||
@classmethod
|
||||
@ -37,11 +46,13 @@ class OpsService:
|
||||
:param tracing_config: tracing config
|
||||
:return:
|
||||
"""
|
||||
if tracing_provider not in provider_config_map.keys() and tracing_provider != None:
|
||||
if tracing_provider not in provider_config_map.keys() and tracing_provider:
|
||||
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
||||
|
||||
config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['other_keys']
|
||||
config_class, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
default_config_instance = config_class(**tracing_config)
|
||||
for key in other_keys:
|
||||
if key in tracing_config and tracing_config[key] == "":
|
||||
@ -51,10 +62,15 @@ class OpsService:
|
||||
if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider):
|
||||
return {"error": "Invalid Credentials"}
|
||||
|
||||
# get project key
|
||||
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
|
||||
|
||||
# check if trace config already exists
|
||||
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
trace_config_data: TraceAppConfig = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if trace_config_data:
|
||||
return None
|
||||
@ -62,6 +78,8 @@ class OpsService:
|
||||
# get tenant id
|
||||
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
|
||||
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
|
||||
if tracing_provider == "langfuse" and project_key:
|
||||
tracing_config["project_key"] = project_key
|
||||
trace_config_data = TraceAppConfig(
|
||||
app_id=app_id,
|
||||
tracing_provider=tracing_provider,
|
||||
@ -85,9 +103,11 @@ class OpsService:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
# check if trace config already exists
|
||||
current_trace_config = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
current_trace_config = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not current_trace_config:
|
||||
return None
|
||||
@ -117,9 +137,11 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
trace_config = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not trace_config:
|
||||
return None
|
||||
|
||||
@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RecommendedAppService:
|
||||
|
||||
builtin_data: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
@ -27,21 +26,21 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
if mode == 'remote':
|
||||
if mode == "remote":
|
||||
try:
|
||||
result = cls._fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.')
|
||||
logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.")
|
||||
result = cls._fetch_recommended_apps_from_builtin(language)
|
||||
elif mode == 'db':
|
||||
elif mode == "db":
|
||||
result = cls._fetch_recommended_apps_from_db(language)
|
||||
elif mode == 'builtin':
|
||||
elif mode == "builtin":
|
||||
result = cls._fetch_recommended_apps_from_builtin(language)
|
||||
else:
|
||||
raise ValueError(f'invalid fetch recommended apps mode: {mode}')
|
||||
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
|
||||
|
||||
if not result.get('recommended_apps') and language != 'en-US':
|
||||
result = cls._fetch_recommended_apps_from_builtin('en-US')
|
||||
if not result.get("recommended_apps") and language != "en-US":
|
||||
result = cls._fetch_recommended_apps_from_builtin("en-US")
|
||||
|
||||
return result
|
||||
|
||||
@ -52,16 +51,18 @@ class RecommendedAppService:
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
recommended_apps = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.language == language
|
||||
).all()
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(recommended_apps) == 0:
|
||||
recommended_apps = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.language == languages[0]
|
||||
).all()
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
.all()
|
||||
)
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
@ -75,28 +76,28 @@ class RecommendedAppService:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
'id': recommended_app.id,
|
||||
'app': {
|
||||
'id': app.id,
|
||||
'name': app.name,
|
||||
'mode': app.mode,
|
||||
'icon': app.icon,
|
||||
'icon_background': app.icon_background
|
||||
"id": recommended_app.id,
|
||||
"app": {
|
||||
"id": app.id,
|
||||
"name": app.name,
|
||||
"mode": app.mode,
|
||||
"icon": app.icon,
|
||||
"icon_background": app.icon_background,
|
||||
},
|
||||
'app_id': recommended_app.app_id,
|
||||
'description': site.description,
|
||||
'copyright': site.copyright,
|
||||
'privacy_policy': site.privacy_policy,
|
||||
'custom_disclaimer': site.custom_disclaimer,
|
||||
'category': recommended_app.category,
|
||||
'position': recommended_app.position,
|
||||
'is_listed': recommended_app.is_listed
|
||||
"app_id": recommended_app.app_id,
|
||||
"description": site.description,
|
||||
"copyright": site.copyright,
|
||||
"privacy_policy": site.privacy_policy,
|
||||
"custom_disclaimer": site.custom_disclaimer,
|
||||
"category": recommended_app.category,
|
||||
"position": recommended_app.position,
|
||||
"is_listed": recommended_app.is_listed,
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(recommended_app.category) # add category to categories
|
||||
|
||||
return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)}
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
@ -106,16 +107,16 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f'{domain}/apps?language={language}'
|
||||
url = f"{domain}/apps?language={language}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}')
|
||||
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
|
||||
|
||||
result = response.json()
|
||||
|
||||
if "categories" in result:
|
||||
result["categories"] = sorted(result["categories"])
|
||||
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@ -126,7 +127,7 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
return builtin_data.get('recommended_apps', {}).get(language)
|
||||
return builtin_data.get("recommended_apps", {}).get(language)
|
||||
|
||||
@classmethod
|
||||
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
|
||||
@ -136,18 +137,18 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
if mode == 'remote':
|
||||
if mode == "remote":
|
||||
try:
|
||||
result = cls._fetch_recommended_app_detail_from_dify_official(app_id)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.')
|
||||
logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.")
|
||||
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
|
||||
elif mode == 'db':
|
||||
elif mode == "db":
|
||||
result = cls._fetch_recommended_app_detail_from_db(app_id)
|
||||
elif mode == 'builtin':
|
||||
elif mode == "builtin":
|
||||
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
|
||||
else:
|
||||
raise ValueError(f'invalid fetch recommended app detail mode: {mode}')
|
||||
raise ValueError(f"invalid fetch recommended app detail mode: {mode}")
|
||||
|
||||
return result
|
||||
|
||||
@ -159,7 +160,7 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f'{domain}/apps/{app_id}'
|
||||
url = f"{domain}/apps/{app_id}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
@ -174,10 +175,11 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
recommended_app = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.app_id == app_id
|
||||
).first()
|
||||
recommended_app = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not recommended_app:
|
||||
return None
|
||||
@ -188,12 +190,12 @@ class RecommendedAppService:
|
||||
return None
|
||||
|
||||
return {
|
||||
'id': app_model.id,
|
||||
'name': app_model.name,
|
||||
'icon': app_model.icon,
|
||||
'icon_background': app_model.icon_background,
|
||||
'mode': app_model.mode,
|
||||
'export_data': AppDslService.export_dsl(app_model=app_model)
|
||||
"id": app_model.id,
|
||||
"name": app_model.name,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"mode": app_model.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=app_model),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -204,7 +206,7 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
return builtin_data.get('app_details', {}).get(app_id)
|
||||
return builtin_data.get("app_details", {}).get(app_id)
|
||||
|
||||
@classmethod
|
||||
def _get_builtin_data(cls) -> dict:
|
||||
@ -216,7 +218,7 @@ class RecommendedAppService:
|
||||
return cls.builtin_data
|
||||
|
||||
root_path = current_app.root_path
|
||||
with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f:
|
||||
with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f:
|
||||
json_data = f.read()
|
||||
data = json.loads(json_data)
|
||||
cls.builtin_data = data
|
||||
@ -229,27 +231,24 @@ class RecommendedAppService:
|
||||
Fetch all recommended apps and export datas
|
||||
:return:
|
||||
"""
|
||||
templates = {
|
||||
"recommended_apps": {},
|
||||
"app_details": {}
|
||||
}
|
||||
templates = {"recommended_apps": {}, "app_details": {}}
|
||||
for language in languages:
|
||||
try:
|
||||
result = cls._fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.')
|
||||
logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.")
|
||||
continue
|
||||
|
||||
templates['recommended_apps'][language] = result
|
||||
templates["recommended_apps"][language] = result
|
||||
|
||||
for recommended_app in result.get('recommended_apps'):
|
||||
app_id = recommended_app.get('app_id')
|
||||
for recommended_app in result.get("recommended_apps"):
|
||||
app_id = recommended_app.get("app_id")
|
||||
|
||||
# get app detail
|
||||
app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id)
|
||||
if not app_detail:
|
||||
continue
|
||||
|
||||
templates['app_details'][app_id] = app_detail
|
||||
templates["app_details"][app_id] = app_detail
|
||||
|
||||
return templates
|
||||
|
||||
@ -10,46 +10,48 @@ from services.message_service import MessageService
|
||||
|
||||
class SavedMessageService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int) -> InfiniteScrollPagination:
|
||||
saved_messages = db.session.query(SavedMessage).filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
SavedMessage.created_by == user.id
|
||||
).order_by(SavedMessage.created_at.desc()).all()
|
||||
def pagination_by_last_id(
|
||||
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
|
||||
) -> InfiniteScrollPagination:
|
||||
saved_messages = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.order_by(SavedMessage.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
message_ids = [sm.message_id for sm in saved_messages]
|
||||
|
||||
return MessageService.pagination_by_last_id(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
last_id=last_id,
|
||||
limit=limit,
|
||||
include_ids=message_ids
|
||||
app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
saved_message = db.session.query(SavedMessage).filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
SavedMessage.created_by == user.id
|
||||
).first()
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if saved_message:
|
||||
return
|
||||
|
||||
message = MessageService.get_message(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
message_id=message_id
|
||||
)
|
||||
message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id)
|
||||
|
||||
saved_message = SavedMessage(
|
||||
app_id=app_model.id,
|
||||
message_id=message.id,
|
||||
created_by_role='account' if isinstance(user, Account) else 'end_user',
|
||||
created_by=user.id
|
||||
created_by_role="account" if isinstance(user, Account) else "end_user",
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
db.session.add(saved_message)
|
||||
@ -57,12 +59,16 @@ class SavedMessageService:
|
||||
|
||||
@classmethod
|
||||
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
saved_message = db.session.query(SavedMessage).filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
SavedMessage.created_by == user.id
|
||||
).first()
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saved_message:
|
||||
return
|
||||
|
||||
@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding
|
||||
class TagService:
|
||||
@staticmethod
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
|
||||
query = db.session.query(
|
||||
Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
|
||||
).outerjoin(
|
||||
TagBinding, Tag.id == TagBinding.tag_id
|
||||
).filter(
|
||||
Tag.type == tag_type,
|
||||
Tag.tenant_id == current_tenant_id
|
||||
query = (
|
||||
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
|
||||
)
|
||||
if keyword:
|
||||
query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
|
||||
query = query.group_by(
|
||||
Tag.id
|
||||
)
|
||||
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.group_by(Tag.id)
|
||||
results = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
|
||||
tags = db.session.query(Tag).filter(
|
||||
Tag.id.in_(tag_ids),
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type
|
||||
).all()
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
)
|
||||
if not tags:
|
||||
return []
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
tag_bindings = db.session.query(
|
||||
TagBinding.target_id
|
||||
).filter(
|
||||
TagBinding.tag_id.in_(tag_ids),
|
||||
TagBinding.tenant_id == current_tenant_id
|
||||
).all()
|
||||
tag_bindings = (
|
||||
db.session.query(TagBinding.target_id)
|
||||
.filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
if not tag_bindings:
|
||||
return []
|
||||
results = [tag_binding.target_id for tag_binding in tag_bindings]
|
||||
@ -51,27 +45,28 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
|
||||
tags = db.session.query(Tag).join(
|
||||
TagBinding,
|
||||
Tag.id == TagBinding.tag_id
|
||||
).filter(
|
||||
TagBinding.target_id == target_id,
|
||||
TagBinding.tenant_id == current_tenant_id,
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type
|
||||
).all()
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.filter(
|
||||
TagBinding.target_id == target_id,
|
||||
TagBinding.tenant_id == current_tenant_id,
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return tags if tags else []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def save_tags(args: dict) -> Tag:
|
||||
tag = Tag(
|
||||
id=str(uuid.uuid4()),
|
||||
name=args['name'],
|
||||
type=args['type'],
|
||||
name=args["name"],
|
||||
type=args["type"],
|
||||
created_by=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
db.session.add(tag)
|
||||
db.session.commit()
|
||||
@ -82,7 +77,7 @@ class TagService:
|
||||
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
tag.name = args['name']
|
||||
tag.name = args["name"]
|
||||
db.session.commit()
|
||||
return tag
|
||||
|
||||
@ -107,20 +102,21 @@ class TagService:
|
||||
@staticmethod
|
||||
def save_tag_binding(args):
|
||||
# check if target exists
|
||||
TagService.check_target_exists(args['type'], args['target_id'])
|
||||
TagService.check_target_exists(args["type"], args["target_id"])
|
||||
# save tag binding
|
||||
for tag_id in args['tag_ids']:
|
||||
tag_binding = db.session.query(TagBinding).filter(
|
||||
TagBinding.tag_id == tag_id,
|
||||
TagBinding.target_id == args['target_id']
|
||||
).first()
|
||||
for tag_id in args["tag_ids"]:
|
||||
tag_binding = (
|
||||
db.session.query(TagBinding)
|
||||
.filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
|
||||
.first()
|
||||
)
|
||||
if tag_binding:
|
||||
continue
|
||||
new_tag_binding = TagBinding(
|
||||
tag_id=tag_id,
|
||||
target_id=args['target_id'],
|
||||
target_id=args["target_id"],
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
created_by=current_user.id
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(new_tag_binding)
|
||||
db.session.commit()
|
||||
@ -128,34 +124,34 @@ class TagService:
|
||||
@staticmethod
|
||||
def delete_tag_binding(args):
|
||||
# check if target exists
|
||||
TagService.check_target_exists(args['type'], args['target_id'])
|
||||
TagService.check_target_exists(args["type"], args["target_id"])
|
||||
# delete tag binding
|
||||
tag_bindings = db.session.query(TagBinding).filter(
|
||||
TagBinding.target_id == args['target_id'],
|
||||
TagBinding.tag_id == (args['tag_id'])
|
||||
).first()
|
||||
tag_bindings = (
|
||||
db.session.query(TagBinding)
|
||||
.filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
|
||||
.first()
|
||||
)
|
||||
if tag_bindings:
|
||||
db.session.delete(tag_bindings)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def check_target_exists(type: str, target_id: str):
|
||||
if type == 'knowledge':
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == current_user.current_tenant_id,
|
||||
Dataset.id == target_id
|
||||
).first()
|
||||
if type == "knowledge":
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found")
|
||||
elif type == 'app':
|
||||
app = db.session.query(App).filter(
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.id == target_id
|
||||
).first()
|
||||
elif type == "app":
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
|
||||
.first()
|
||||
)
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
else:
|
||||
raise NotFound("Invalid binding type")
|
||||
|
||||
|
||||
@ -29,111 +29,107 @@ class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
try:
|
||||
warnings = {}
|
||||
try:
|
||||
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
credentials_schema = [
|
||||
ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
name="auth_type",
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
required=True,
|
||||
default='none',
|
||||
default="none",
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(
|
||||
en_US='None',
|
||||
zh_Hans='无'
|
||||
)),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(
|
||||
en_US='Api Key',
|
||||
zh_Hans='Api Key'
|
||||
)),
|
||||
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
],
|
||||
placeholder=I18nObject(
|
||||
en_US='Select auth type',
|
||||
zh_Hans='选择认证方式'
|
||||
)
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_header',
|
||||
name="api_key_header",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key header',
|
||||
zh_Hans='输入 api key header,如:X-API-KEY'
|
||||
),
|
||||
default='api_key',
|
||||
help=I18nObject(
|
||||
en_US='HTTP header name for api key',
|
||||
zh_Hans='HTTP 头部字段名,用于传递 api key'
|
||||
)
|
||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||
default="api_key",
|
||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_value',
|
||||
name="api_key_value",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key',
|
||||
zh_Hans='输入 api key'
|
||||
),
|
||||
default=''
|
||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||
default="",
|
||||
),
|
||||
]
|
||||
|
||||
return jsonable_encoder({
|
||||
'schema_type': schema_type,
|
||||
'parameters_schema': tool_bundles,
|
||||
'credentials_schema': credentials_schema,
|
||||
'warning': warnings
|
||||
})
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
return tool_bundles
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def create_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
create api tool provider
|
||||
create api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is not None:
|
||||
raise ValueError(f'provider {provider_name} already exists')
|
||||
raise ValueError(f"provider {provider_name} already exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
|
||||
if len(tool_bundles) > 100:
|
||||
raise ValueError('the number of apis should be less than 100')
|
||||
raise ValueError("the number of apis should be less than 100")
|
||||
|
||||
# create db provider
|
||||
db_provider = ApiToolProvider(
|
||||
@ -142,19 +138,19 @@ class ApiToolManageService:
|
||||
name=provider_name,
|
||||
icon=json.dumps(icon),
|
||||
schema=schema,
|
||||
description=extra_info.get('description', ''),
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
@ -172,14 +168,12 @@ class ApiToolManageService:
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider_remote_schema(
|
||||
user_id: str, tenant_id: str, url: str
|
||||
):
|
||||
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
|
||||
"""
|
||||
get api tool provider remote schema
|
||||
get api tool provider remote schema
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
|
||||
@ -189,84 +183,98 @@ class ApiToolManageService:
|
||||
try:
|
||||
response = get(url, headers=headers, timeout=10)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'Got status code {response.status_code}')
|
||||
raise ValueError(f"Got status code {response.status_code}")
|
||||
schema = response.text
|
||||
|
||||
# try to parse schema, avoid SSRF attack
|
||||
ApiToolManageService.parser_api_schema(schema)
|
||||
except Exception as e:
|
||||
logger.error(f"parse api schema error: {str(e)}")
|
||||
raise ValueError('invalid schema, please check the url you provided')
|
||||
|
||||
return {
|
||||
'schema': schema
|
||||
}
|
||||
raise ValueError("invalid schema, please check the url you provided")
|
||||
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool_bundle,
|
||||
labels=labels,
|
||||
) for tool_bundle in provider.tools
|
||||
)
|
||||
for tool_bundle in provider.tools
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def update_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
original_provider: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
update api tool provider
|
||||
update api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'api provider {provider_name} does not exists')
|
||||
raise ValueError(f"api provider {provider_name} does not exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
|
||||
# update db provider
|
||||
provider.name = provider_name
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get('description', '')
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
@ -295,84 +303,91 @@ class ApiToolManageService:
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str
|
||||
):
|
||||
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
delete tool provider
|
||||
delete tool provider
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get api tool provider
|
||||
get api tool provider
|
||||
"""
|
||||
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def test_api_tool_preview(
|
||||
tenant_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
):
|
||||
"""
|
||||
test api tool before adding api tool provider
|
||||
test api tool before adding api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema_type}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema_type}")
|
||||
|
||||
try:
|
||||
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
|
||||
except Exception as e:
|
||||
raise ValueError('invalid schema')
|
||||
|
||||
raise ValueError("invalid schema")
|
||||
|
||||
# get tool bundle
|
||||
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f'invalid tool name {tool_name}')
|
||||
|
||||
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
# create a fake db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id='', user_id='', name='', icon='',
|
||||
tenant_id="",
|
||||
user_id="",
|
||||
name="",
|
||||
icon="",
|
||||
schema=schema,
|
||||
description='',
|
||||
description="",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
@ -381,10 +396,7 @@ class ApiToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
@ -396,27 +408,27 @@ class ApiToolManageService:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(runtime={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return { 'error': str(e) }
|
||||
|
||||
return { 'result': result or 'empty response' }
|
||||
|
||||
return {"error": str(e)}
|
||||
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
"""
|
||||
list api tools
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
@ -425,26 +437,21 @@ class ApiToolManageService:
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(provider_controller)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller,
|
||||
db_provider=provider,
|
||||
decrypt_credentials=True
|
||||
provider_controller, db_provider=provider, decrypt_credentials=True
|
||||
)
|
||||
user_provider.labels = labels
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(
|
||||
user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
|
||||
for tool in tools:
|
||||
user_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_provider.original_credentials,
|
||||
labels=labels
|
||||
))
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_provider)
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
@ -18,21 +20,25 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BuiltinToolManageService:
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
) -> list[UserTool]:
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
list builtin tool provider tools
|
||||
"""
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_provider_configurations = ToolConfigurationManager(
|
||||
tenant_id=tenant_id, provider_controller=provider_controller
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
).first()
|
||||
builtin_provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
@ -42,47 +48,47 @@ class BuiltinToolManageService:
|
||||
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
result.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name
|
||||
):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:return: the list of tool providers
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([
|
||||
v for _, v in (provider.credentials_schema or {}).items()
|
||||
])
|
||||
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, credentials: dict
|
||||
):
|
||||
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
|
||||
"""
|
||||
update builtin tool provider
|
||||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f'provider {provider_name} does not need credentials')
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
# get original credentials if exists
|
||||
if provider is not None:
|
||||
@ -119,23 +125,25 @@ class BuiltinToolManageService:
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
).first()
|
||||
provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
return {}
|
||||
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
@ -143,20 +151,22 @@ class BuiltinToolManageService:
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str
|
||||
):
|
||||
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
delete tool provider
|
||||
delete tool provider
|
||||
"""
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
@ -165,48 +175,55 @@ class BuiltinToolManageService:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(
|
||||
provider: str
|
||||
):
|
||||
def get_builtin_tool_provider_icon(provider: str):
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
|
||||
with open(icon_path, 'rb') as f:
|
||||
with open(icon_path, "rb") as f:
|
||||
icon_bytes = f.read()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
"""
|
||||
list builtin tools
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
# find provider
|
||||
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
find_provider = lambda provider: next(
|
||||
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.identity.name),
|
||||
decrypt_credentials=True
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
# add icon
|
||||
@ -214,16 +231,17 @@ class BuiltinToolManageService:
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools:
|
||||
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels
|
||||
class ToolLabelsService:
|
||||
@classmethod
|
||||
def list_tool_labels(cls) -> list[ToolLabel]:
|
||||
return default_tool_labels
|
||||
return default_tool_labels
|
||||
|
||||
@ -11,13 +11,11 @@ class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
|
||||
"""
|
||||
list tool providers
|
||||
list tool providers
|
||||
|
||||
:return: the list of tool providers
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
providers = ToolManager.user_list_providers(
|
||||
user_id, tenant_id, typ
|
||||
)
|
||||
providers = ToolManager.user_list_providers(user_id, tenant_id, typ)
|
||||
|
||||
# add icon
|
||||
for provider in providers:
|
||||
@ -26,4 +24,3 @@ class ToolCommonService:
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
return result
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from configs import dify_config
|
||||
@ -9,7 +8,6 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
@ -24,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
get tool provider icon url
|
||||
"""
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ "/console/api/workspaces/current/tool-provider/")
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
return url_prefix + 'builtin/' + provider_name + '/icon'
|
||||
return url_prefix + "builtin/" + provider_name + "/icon"
|
||||
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
|
||||
try:
|
||||
return json.loads(icon)
|
||||
except:
|
||||
return {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
return ''
|
||||
|
||||
@classmethod
|
||||
def repack_provider(cls, provider: Union[dict, UserToolProvider]):
|
||||
"""
|
||||
repack provider
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
:param provider: the provider dict
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
"""
|
||||
if isinstance(provider, dict) and 'icon' in provider:
|
||||
provider['icon'] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider['type'],
|
||||
provider_name=provider['name'],
|
||||
icon=provider['icon']
|
||||
repack provider
|
||||
|
||||
:param provider: the provider dict
|
||||
"""
|
||||
if isinstance(provider, dict) and "icon" in provider:
|
||||
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value,
|
||||
provider_name=provider.name,
|
||||
icon=provider.icon
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -95,14 +86,13 @@ class ToolTransformService:
|
||||
masked_credentials={},
|
||||
is_team_authorization=False,
|
||||
tools=[],
|
||||
labels=provider_controller.tool_labels
|
||||
labels=provider_controller.tool_labels,
|
||||
)
|
||||
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result.masked_credentials[name] = \
|
||||
ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
@ -116,8 +106,7 @@ class ToolTransformService:
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
@ -127,10 +116,9 @@ class ToolTransformService:
|
||||
result.original_credentials = decrypted_credentials
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
|
||||
@staticmethod
|
||||
def api_provider_to_controller(
|
||||
cls,
|
||||
db_provider: ApiToolProvider,
|
||||
) -> ApiToolProviderController:
|
||||
"""
|
||||
@ -139,26 +127,23 @@ class ToolTransformService:
|
||||
# package tool provider controller
|
||||
controller = ApiToolProviderController.from_db(
|
||||
db_provider=db_provider,
|
||||
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
|
||||
ApiProviderAuthType.NONE
|
||||
auth_type=ApiProviderAuthType.API_KEY
|
||||
if db_provider.credentials["auth_type"] == "api_key"
|
||||
else ApiProviderAuthType.NONE,
|
||||
)
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def workflow_provider_to_controller(
|
||||
cls,
|
||||
db_provider: WorkflowToolProvider
|
||||
) -> WorkflowToolProviderController:
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
|
||||
"""
|
||||
convert provider controller to provider
|
||||
"""
|
||||
return WorkflowToolProviderController.from_db(db_provider)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController,
|
||||
labels: list[str] = None
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] = None
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
@ -180,7 +165,7 @@ class ToolTransformService:
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
labels=labels or []
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -189,16 +174,16 @@ class ToolTransformService:
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: list[str] = None
|
||||
labels: list[str] = None,
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
username = 'Anonymous'
|
||||
username = "Anonymous"
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
except Exception as e:
|
||||
logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}')
|
||||
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
@ -218,14 +203,13 @@ class ToolTransformService:
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
labels=labels or []
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
@ -235,24 +219,25 @@ class ToolTransformService:
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
cls,
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: dict = None,
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: dict = None,
|
||||
tenant_id: str = None,
|
||||
labels: list[str] = None
|
||||
labels: list[str] = None,
|
||||
) -> UserTool:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(runtime={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
)
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
@ -277,25 +262,14 @@ class ToolTransformService:
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human,
|
||||
parameters=current_parameters,
|
||||
labels=labels
|
||||
labels=labels,
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
author=tool.author,
|
||||
name=tool.operation_id,
|
||||
label=I18nObject(
|
||||
en_US=tool.operation_id,
|
||||
zh_Hans=tool.operation_id
|
||||
),
|
||||
description=I18nObject(
|
||||
en_US=tool.summary or '',
|
||||
zh_Hans=tool.summary or ''
|
||||
),
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
parameters=tool.parameters,
|
||||
labels=labels
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def transform_messages_to_dict(cls, responses: Generator[ToolInvokeMessage, None, None]):
|
||||
for response in responses:
|
||||
yield response.model_dump()
|
||||
@ -19,10 +19,21 @@ class WorkflowToolManageService:
|
||||
"""
|
||||
Service class for managing workflow tools.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str,
|
||||
label: str, icon: dict, description: str,
|
||||
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
|
||||
def create_workflow_tool(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_app_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[dict],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -32,32 +43,34 @@ class WorkflowToolManageService:
|
||||
:param description: the description
|
||||
:param parameters: the parameters
|
||||
:param privacy_policy: the privacy policy
|
||||
:param labels: labels
|
||||
:return: the created tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
).first()
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
|
||||
|
||||
app: App = db.session.query(App).filter(
|
||||
App.id == workflow_app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f'App {workflow_app_id} not found')
|
||||
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f'Workflow not found for app {workflow_app_id}')
|
||||
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
workflow_tool_provider = WorkflowToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
@ -75,58 +88,76 @@ class WorkflowToolManageService:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
|
||||
db.session.add(workflow_tool_provider)
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
|
||||
name: str, label: str, icon: dict, description: str,
|
||||
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
|
||||
def update_workflow_tool(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_tool_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[dict],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param tool: the tool
|
||||
:param workflow_tool_id: workflow tool id
|
||||
:param name: name
|
||||
:param label: label
|
||||
:param icon: icon
|
||||
:param description: description
|
||||
:param parameters: parameters
|
||||
:param privacy_policy: privacy policy
|
||||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id
|
||||
).first()
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f'Tool with name {name} already exists')
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
|
||||
app: App = db.session.query(App).filter(
|
||||
App.id == workflow_tool_provider.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App = (
|
||||
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f'App {workflow_tool_provider.app_id} not found')
|
||||
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
|
||||
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
@ -146,13 +177,10 @@ class WorkflowToolManageService:
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
|
||||
labels
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
@ -162,9 +190,7 @@ class WorkflowToolManageService:
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id
|
||||
).all()
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
tools = []
|
||||
for provider in db_tools:
|
||||
@ -180,14 +206,12 @@ class WorkflowToolManageService:
|
||||
|
||||
for tool in tools:
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool,
|
||||
labels=labels.get(tool.provider_id, [])
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.repack_provider(user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=labels.get(tool.provider_id, [])
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
@ -203,15 +227,12 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
"""
|
||||
db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
|
||||
).delete()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
@ -222,40 +243,37 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(
|
||||
App.id == db_tool.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f'App {db_tool.app_id} not found')
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return {
|
||||
'name': db_tool.name,
|
||||
'label': db_tool.label,
|
||||
'workflow_tool_id': db_tool.id,
|
||||
'workflow_app_id': db_tool.app_id,
|
||||
'icon': json.loads(db_tool.icon),
|
||||
'description': db_tool.description,
|
||||
'parameters': jsonable_encoder(db_tool.parameter_configurations),
|
||||
'tool': ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
'synced': workflow_app.workflow.version == db_tool.version,
|
||||
'privacy_policy': db_tool.privacy_policy,
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
"""
|
||||
@ -265,40 +283,37 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == workflow_app_id
|
||||
).first()
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_app_id} not found')
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(
|
||||
App.id == db_tool.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f'App {db_tool.app_id} not found')
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return {
|
||||
'name': db_tool.name,
|
||||
'label': db_tool.label,
|
||||
'workflow_tool_id': db_tool.id,
|
||||
'workflow_app_id': db_tool.app_id,
|
||||
'icon': json.loads(db_tool.icon),
|
||||
'description': db_tool.description,
|
||||
'parameters': jsonable_encoder(db_tool.parameter_configurations),
|
||||
'tool': ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
'synced': workflow_app.workflow.version == db_tool.version,
|
||||
'privacy_policy': db_tool.privacy_policy
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
|
||||
"""
|
||||
@ -308,19 +323,19 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
)
|
||||
]
|
||||
]
|
||||
|
||||
@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class VectorService:
|
||||
|
||||
@classmethod
|
||||
def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
|
||||
segments: list[DocumentSegment], dataset: Dataset):
|
||||
def create_segments_vector(
|
||||
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
|
||||
):
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
@ -20,14 +20,12 @@ class VectorService:
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# save vector index
|
||||
vector = Vector(
|
||||
dataset=dataset
|
||||
)
|
||||
vector = Vector(dataset=dataset)
|
||||
vector.add_texts(documents, duplicate_check=True)
|
||||
|
||||
# save keyword index
|
||||
@ -50,13 +48,11 @@ class VectorService:
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# update vector index
|
||||
vector = Vector(
|
||||
dataset=dataset
|
||||
)
|
||||
vector = Vector(dataset=dataset)
|
||||
vector.delete_by_ids([segment.index_node_id])
|
||||
vector.add_texts([document], duplicate_check=True)
|
||||
|
||||
|
||||
@ -11,17 +11,29 @@ from services.conversation_service import ConversationService
|
||||
|
||||
class WebConversationService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int, invoke_from: InvokeFrom,
|
||||
pinned: Optional[bool] = None) -> InfiniteScrollPagination:
|
||||
def pagination_by_last_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
pinned: Optional[bool] = None,
|
||||
sort_by="-updated_at",
|
||||
) -> InfiniteScrollPagination:
|
||||
include_ids = None
|
||||
exclude_ids = None
|
||||
if pinned is not None:
|
||||
pinned_conversations = db.session.query(PinnedConversation).filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
PinnedConversation.created_by == user.id
|
||||
).order_by(PinnedConversation.created_at.desc()).all()
|
||||
pinned_conversations = (
|
||||
db.session.query(PinnedConversation)
|
||||
.filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.order_by(PinnedConversation.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
|
||||
if pinned:
|
||||
include_ids = pinned_conversation_ids
|
||||
@ -36,31 +48,34 @@ class WebConversationService:
|
||||
invoke_from=invoke_from,
|
||||
include_ids=include_ids,
|
||||
exclude_ids=exclude_ids,
|
||||
sort_by=sort_by,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
pinned_conversation = db.session.query(PinnedConversation).filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
PinnedConversation.created_by == user.id
|
||||
).first()
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if pinned_conversation:
|
||||
return
|
||||
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
conversation_id=conversation_id,
|
||||
user=user
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
pinned_conversation = PinnedConversation(
|
||||
app_id=app_model.id,
|
||||
conversation_id=conversation.id,
|
||||
created_by_role='account' if isinstance(user, Account) else 'end_user',
|
||||
created_by=user.id
|
||||
created_by_role="account" if isinstance(user, Account) else "end_user",
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
db.session.add(pinned_conversation)
|
||||
@ -68,12 +83,16 @@ class WebConversationService:
|
||||
|
||||
@classmethod
|
||||
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
pinned_conversation = db.session.query(PinnedConversation).filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
PinnedConversation.created_by == user.id
|
||||
).first()
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pinned_conversation:
|
||||
return
|
||||
|
||||
@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if 'url' not in args or not args['url']:
|
||||
raise ValueError('url is required')
|
||||
if 'options' not in args or not args['options']:
|
||||
raise ValueError('options is required')
|
||||
if 'limit' not in args['options'] or not args['options']['limit']:
|
||||
raise ValueError('limit is required')
|
||||
if "url" not in args or not args["url"]:
|
||||
raise ValueError("url is required")
|
||||
if "options" not in args or not args["options"]:
|
||||
raise ValueError("options is required")
|
||||
if "limit" not in args["options"] or not args["options"]["limit"]:
|
||||
raise ValueError("limit is required")
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, args: dict) -> dict:
|
||||
provider = args.get('provider')
|
||||
url = args.get('url')
|
||||
options = args.get('options')
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
provider = args.get("provider")
|
||||
url = args.get("url")
|
||||
options = args.get("options")
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
crawl_sub_pages = options.get('crawl_sub_pages', False)
|
||||
only_main_content = options.get('only_main_content', False)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
crawl_sub_pages = options.get("crawl_sub_pages", False)
|
||||
only_main_content = options.get("only_main_content", False)
|
||||
if not crawl_sub_pages:
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"crawlerOptions": {
|
||||
"includes": [],
|
||||
"excludes": [],
|
||||
"generateImgAltText": True,
|
||||
"limit": 1,
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
"returnOnlyUrls": False,
|
||||
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
|
||||
}
|
||||
}
|
||||
else:
|
||||
includes = options.get('includes').split(',') if options.get('includes') else []
|
||||
excludes = options.get('excludes').split(',') if options.get('excludes') else []
|
||||
includes = options.get("includes").split(",") if options.get("includes") else []
|
||||
excludes = options.get("excludes").split(",") if options.get("excludes") else []
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"crawlerOptions": {
|
||||
"includes": includes if includes else [],
|
||||
"excludes": excludes if excludes else [],
|
||||
"generateImgAltText": True,
|
||||
"limit": options.get('limit', 1),
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
"limit": options.get("limit", 1),
|
||||
"returnOnlyUrls": False,
|
||||
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
|
||||
}
|
||||
}
|
||||
if options.get('max_depth'):
|
||||
params['crawlerOptions']['maxDepth'] = options.get('max_depth')
|
||||
if options.get("max_depth"):
|
||||
params["crawlerOptions"]["maxDepth"] = options.get("max_depth")
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
time = str(datetime.datetime.now().timestamp())
|
||||
redis_client.setex(website_crawl_time_cache_key, 3600, time)
|
||||
return {
|
||||
'status': 'active',
|
||||
'job_id': job_id
|
||||
}
|
||||
return {"status": "active", "job_id": job_id}
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
'status': result.get('status', 'active'),
|
||||
'job_id': job_id,
|
||||
'total': result.get('total', 0),
|
||||
'current': result.get('current', 0),
|
||||
'data': result.get('data', [])
|
||||
"status": result.get("status", "active"),
|
||||
"job_id": job_id,
|
||||
"total": result.get("total", 0),
|
||||
"current": result.get("current", 0),
|
||||
"data": result.get("data", []),
|
||||
}
|
||||
if crawl_status_data['status'] == 'completed':
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
if crawl_status_data["status"] == "completed":
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
start_time = redis_client.get(website_crawl_time_cache_key)
|
||||
if start_time:
|
||||
end_time = datetime.datetime.now().timestamp()
|
||||
time_consuming = abs(end_time - float(start_time))
|
||||
crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
|
||||
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
|
||||
redis_client.delete(website_crawl_time_cache_key)
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
data = storage.load_once(file_key)
|
||||
if data:
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
data = json.loads(data.decode("utf-8"))
|
||||
else:
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get('status') != 'completed':
|
||||
raise ValueError('Crawl job is not completed')
|
||||
data = result.get('data')
|
||||
if result.get("status") != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
data = result.get("data")
|
||||
if data:
|
||||
for item in data:
|
||||
if item.get('source_url') == url:
|
||||
if item.get("source_url") == url:
|
||||
return item
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
params = {
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}}
|
||||
result = firecrawl_app.scrape_url(url, params)
|
||||
return result
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@ -6,7 +6,6 @@ from core.app.app_config.entities import (
|
||||
DatasetRetrieveConfigEntity,
|
||||
EasyUIBasedAppConfig,
|
||||
ExternalDataVariableEntity,
|
||||
FileExtraConfig,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
@ -14,6 +13,7 @@ from core.app.app_config.entities import (
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.file.file_obj import FileExtraConfig
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -32,11 +32,9 @@ class WorkflowConverter:
|
||||
App Convert to Workflow Mode
|
||||
"""
|
||||
|
||||
def convert_to_workflow(self, app_model: App,
|
||||
account: Account,
|
||||
name: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def convert_to_workflow(
|
||||
self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str
|
||||
):
|
||||
"""
|
||||
Convert app to workflow
|
||||
|
||||
@ -50,22 +48,24 @@ class WorkflowConverter:
|
||||
:param account: Account
|
||||
:param name: new app name
|
||||
:param icon: new app icon
|
||||
:param icon_type: new app icon type
|
||||
:param icon_background: new app icon background
|
||||
:return: new App instance
|
||||
"""
|
||||
# convert app model config
|
||||
if not app_model.app_model_config:
|
||||
raise ValueError("App model config is required")
|
||||
|
||||
workflow = self.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model.app_model_config,
|
||||
account_id=account.id
|
||||
app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id
|
||||
)
|
||||
|
||||
# create new app
|
||||
new_app = App()
|
||||
new_app.tenant_id = app_model.tenant_id
|
||||
new_app.name = name if name else app_model.name + '(workflow)'
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value \
|
||||
if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.name = name if name else app_model.name + "(workflow)"
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.icon_type = icon_type if icon_type else app_model.icon_type
|
||||
new_app.icon = icon if icon else app_model.icon
|
||||
new_app.icon_background = icon_background if icon_background else app_model.icon_background
|
||||
new_app.enable_site = app_model.enable_site
|
||||
@ -74,6 +74,8 @@ class WorkflowConverter:
|
||||
new_app.api_rph = app_model.api_rph
|
||||
new_app.is_demo = False
|
||||
new_app.is_public = app_model.is_public
|
||||
new_app.created_by = account.id
|
||||
new_app.updated_by = account.id
|
||||
db.session.add(new_app)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
@ -85,30 +87,21 @@ class WorkflowConverter:
|
||||
|
||||
return new_app
|
||||
|
||||
def convert_app_model_config_to_workflow(self, app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
account_id: str) -> Workflow:
|
||||
def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str):
|
||||
"""
|
||||
Convert app model config to workflow mode
|
||||
:param app_model: App instance
|
||||
:param app_model_config: AppModelConfig instance
|
||||
:param account_id: Account ID
|
||||
:return:
|
||||
"""
|
||||
# get new app mode
|
||||
new_app_mode = self._get_new_app_mode(app_model)
|
||||
|
||||
# convert app model config
|
||||
app_config = self._convert_to_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
|
||||
# init workflow graph
|
||||
graph = {
|
||||
"nodes": [],
|
||||
"edges": []
|
||||
}
|
||||
graph = {"nodes": [], "edges": []}
|
||||
|
||||
# Convert list:
|
||||
# - variables -> start
|
||||
@ -120,11 +113,9 @@ class WorkflowConverter:
|
||||
# - show_retrieve_source -> knowledge-retrieval
|
||||
|
||||
# convert to start node
|
||||
start_node = self._convert_to_start_node(
|
||||
variables=app_config.variables
|
||||
)
|
||||
start_node = self._convert_to_start_node(variables=app_config.variables)
|
||||
|
||||
graph['nodes'].append(start_node)
|
||||
graph["nodes"].append(start_node)
|
||||
|
||||
# convert to http request node
|
||||
external_data_variable_node_mapping = {}
|
||||
@ -132,7 +123,7 @@ class WorkflowConverter:
|
||||
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=app_config.variables,
|
||||
external_data_variables=app_config.external_data_variables
|
||||
external_data_variables=app_config.external_data_variables,
|
||||
)
|
||||
|
||||
for http_request_node in http_request_nodes:
|
||||
@ -141,9 +132,7 @@ class WorkflowConverter:
|
||||
# convert to knowledge retrieval node
|
||||
if app_config.dataset:
|
||||
knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode,
|
||||
dataset_config=app_config.dataset,
|
||||
model_config=app_config.model
|
||||
new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model
|
||||
)
|
||||
|
||||
if knowledge_retrieval_node:
|
||||
@ -157,7 +146,7 @@ class WorkflowConverter:
|
||||
model_config=app_config.model,
|
||||
prompt_template=app_config.prompt_template,
|
||||
file_upload=app_config.additional_features.file_upload,
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
|
||||
graph = self._append_node(graph, llm_node)
|
||||
@ -196,11 +185,12 @@ class WorkflowConverter:
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(new_app_mode).value,
|
||||
version='draft',
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account_id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
db.session.add(workflow)
|
||||
@ -208,24 +198,18 @@ class WorkflowConverter:
|
||||
|
||||
return workflow
|
||||
|
||||
def _convert_to_app_config(self, app_model: App,
|
||||
app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
|
||||
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
app_model.mode = AppMode.AGENT_CHAT.value
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
elif app_mode == AppMode.CHAT:
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
@ -244,14 +228,13 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": "START",
|
||||
"type": NodeType.START.value,
|
||||
"variables": [jsonable_encoder(v) for v in variables]
|
||||
}
|
||||
"variables": [jsonable_encoder(v) for v in variables],
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_http_request_node(self, app_model: App,
|
||||
variables: list[VariableEntity],
|
||||
external_data_variables: list[ExternalDataVariableEntity]) \
|
||||
-> tuple[list[dict], dict[str, str]]:
|
||||
def _convert_to_http_request_node(
|
||||
self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity]
|
||||
) -> tuple[list[dict], dict[str, str]]:
|
||||
"""
|
||||
Convert API Based Extension to HTTP Request Node
|
||||
:param app_model: App instance
|
||||
@ -273,40 +256,33 @@ class WorkflowConverter:
|
||||
|
||||
# get params from config
|
||||
api_based_extension_id = tool_config.get("api_based_extension_id")
|
||||
if not api_based_extension_id:
|
||||
continue
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = self._get_api_based_extension(
|
||||
tenant_id=tenant_id,
|
||||
api_based_extension_id=api_based_extension_id
|
||||
tenant_id=tenant_id, api_based_extension_id=api_based_extension_id
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError("[External data tool] API query failed, variable: {}, "
|
||||
"error: api_based_extension_id is invalid"
|
||||
.format(tool_variable))
|
||||
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=api_based_extension.api_key
|
||||
)
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key)
|
||||
|
||||
inputs = {}
|
||||
for v in variables:
|
||||
inputs[v.variable] = '{{#start.' + v.variable + '#}}'
|
||||
inputs[v.variable] = "{{#start." + v.variable + "#}}"
|
||||
|
||||
request_body = {
|
||||
'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
'params': {
|
||||
'app_id': app_model.id,
|
||||
'tool_variable': tool_variable,
|
||||
'inputs': inputs,
|
||||
'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else ''
|
||||
}
|
||||
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
"params": {
|
||||
"app_id": app_model.id,
|
||||
"tool_variable": tool_variable,
|
||||
"inputs": inputs,
|
||||
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
|
||||
},
|
||||
}
|
||||
|
||||
request_body_json = json.dumps(request_body)
|
||||
request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}')
|
||||
request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}")
|
||||
|
||||
http_request_node = {
|
||||
"id": f"http_request_{index}",
|
||||
@ -316,20 +292,11 @@ class WorkflowConverter:
|
||||
"type": NodeType.HTTP_REQUEST.value,
|
||||
"method": "post",
|
||||
"url": api_based_extension.api_endpoint,
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "bearer",
|
||||
"api_key": api_key
|
||||
}
|
||||
},
|
||||
"authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": {
|
||||
"type": "json",
|
||||
"data": request_body_json
|
||||
}
|
||||
}
|
||||
"body": {"type": "json", "data": request_body_json},
|
||||
},
|
||||
}
|
||||
|
||||
nodes.append(http_request_node)
|
||||
@ -341,32 +308,24 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": f"Parse {api_based_extension.name} Response",
|
||||
"type": NodeType.CODE.value,
|
||||
"variables": [{
|
||||
"variable": "response_json",
|
||||
"value_selector": [http_request_node['id'], "body"]
|
||||
}],
|
||||
"variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
|
||||
"code_language": "python3",
|
||||
"code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads("
|
||||
"response_json)\n return {\n \"result\": response_body[\"result\"]\n }",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
'response_json)\n return {\n "result": response_body["result"]\n }',
|
||||
"outputs": {"result": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
|
||||
nodes.append(code_node)
|
||||
|
||||
external_data_variable_node_mapping[external_data_variable.variable] = code_node['id']
|
||||
external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"]
|
||||
index += 1
|
||||
|
||||
return nodes, external_data_variable_node_mapping
|
||||
|
||||
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
|
||||
dataset_config: DatasetEntity,
|
||||
model_config: ModelConfigEntity) \
|
||||
-> Optional[dict]:
|
||||
def _convert_to_knowledge_retrieval_node(
|
||||
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Convert datasets to Knowledge Retrieval Node
|
||||
:param new_app_mode: new app mode
|
||||
@ -400,7 +359,7 @@ class WorkflowConverter:
|
||||
"completion_params": {
|
||||
**model_config.parameters,
|
||||
"stop": model_config.stop,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
@ -408,20 +367,23 @@ class WorkflowConverter:
|
||||
"multiple_retrieval_config": {
|
||||
"top_k": retrieve_config.top_k,
|
||||
"score_threshold": retrieve_config.score_threshold,
|
||||
"reranking_model": retrieve_config.reranking_model
|
||||
"reranking_model": retrieve_config.reranking_model,
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
|
||||
else None,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_llm_node(self, original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> dict:
|
||||
def _convert_to_llm_node(
|
||||
self,
|
||||
original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param original_app_mode: original app mode
|
||||
@ -433,17 +395,18 @@ class WorkflowConverter:
|
||||
:param external_data_variable_node_mapping: external data variable node mapping
|
||||
"""
|
||||
# fetch start and knowledge retrieval node
|
||||
start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes']))
|
||||
knowledge_retrieval_node = next(filter(
|
||||
lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value,
|
||||
graph['nodes']
|
||||
), None)
|
||||
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
|
||||
knowledge_retrieval_node = next(
|
||||
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
|
||||
)
|
||||
|
||||
role_prefix = None
|
||||
|
||||
# Chat Model
|
||||
if model_config.mode == LLMMode.CHAT.value:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
if not prompt_template.simple_prompt_template:
|
||||
raise ValueError("Simple prompt template is required")
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
@ -452,45 +415,35 @@ class WorkflowConverter:
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
if not template:
|
||||
prompts = []
|
||||
else:
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template, start_node["data"]["variables"], external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{
|
||||
"role": 'user',
|
||||
"text": template
|
||||
}
|
||||
]
|
||||
prompts = [{"role": "user", "text": template}]
|
||||
else:
|
||||
advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
|
||||
|
||||
prompts = []
|
||||
for m in advanced_chat_prompt_template.messages:
|
||||
if advanced_chat_prompt_template:
|
||||
if advanced_chat_prompt_template:
|
||||
for m in advanced_chat_prompt_template.messages:
|
||||
text = m.text
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
text, start_node["data"]["variables"], external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts.append({
|
||||
"role": m.role.value,
|
||||
"text": text
|
||||
})
|
||||
prompts.append({"role": m.role.value, "text": text})
|
||||
# Completion Model
|
||||
else:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
if not prompt_template.simple_prompt_template:
|
||||
raise ValueError("Simple prompt template is required")
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
@ -499,57 +452,50 @@ class WorkflowConverter:
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template=template,
|
||||
variables=start_node["data"]["variables"],
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
|
||||
prompts = {
|
||||
"text": template
|
||||
}
|
||||
prompts = {"text": template}
|
||||
|
||||
prompt_rules = prompt_template_config['prompt_rules']
|
||||
prompt_rules = prompt_template_config["prompt_rules"]
|
||||
role_prefix = {
|
||||
"user": prompt_rules.get('human_prefix', 'Human'),
|
||||
"assistant": prompt_rules.get('assistant_prefix', 'Assistant')
|
||||
"user": prompt_rules.get("human_prefix", "Human"),
|
||||
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
|
||||
}
|
||||
else:
|
||||
advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
|
||||
if advanced_completion_prompt_template:
|
||||
text = advanced_completion_prompt_template.prompt
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template=text,
|
||||
variables=start_node["data"]["variables"],
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
text = text.replace('{{#query#}}', '{{#sys.query#}}')
|
||||
text = text.replace("{{#query#}}", "{{#sys.query#}}")
|
||||
|
||||
prompts = {
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if advanced_completion_prompt_template.role_prefix:
|
||||
if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix:
|
||||
role_prefix = {
|
||||
"user": advanced_completion_prompt_template.role_prefix.user,
|
||||
"assistant": advanced_completion_prompt_template.role_prefix.assistant
|
||||
"assistant": advanced_completion_prompt_template.role_prefix.assistant,
|
||||
}
|
||||
|
||||
memory = None
|
||||
if new_app_mode == AppMode.ADVANCED_CHAT:
|
||||
memory = {
|
||||
"role_prefix": role_prefix,
|
||||
"window": {
|
||||
"enabled": False
|
||||
}
|
||||
}
|
||||
memory = {"role_prefix": role_prefix, "window": {"enabled": False}}
|
||||
|
||||
completion_params = model_config.parameters
|
||||
completion_params.update({"stop": model_config.stop})
|
||||
@ -563,41 +509,42 @@ class WorkflowConverter:
|
||||
"provider": model_config.provider,
|
||||
"name": model_config.model,
|
||||
"mode": model_config.mode,
|
||||
"completion_params": completion_params
|
||||
"completion_params": completion_params,
|
||||
},
|
||||
"prompt_template": prompts,
|
||||
"memory": memory,
|
||||
"context": {
|
||||
"enabled": knowledge_retrieval_node is not None,
|
||||
"variable_selector": ["knowledge_retrieval", "result"]
|
||||
if knowledge_retrieval_node is not None else None
|
||||
if knowledge_retrieval_node is not None
|
||||
else None,
|
||||
},
|
||||
"vision": {
|
||||
"enabled": file_upload is not None,
|
||||
"variable_selector": ["sys", "files"] if file_upload is not None else None,
|
||||
"configs": {
|
||||
"detail": file_upload.image_config['detail']
|
||||
} if file_upload is not None else None
|
||||
}
|
||||
}
|
||||
"configs": {"detail": file_upload.image_config["detail"]}
|
||||
if file_upload is not None and file_upload.image_config is not None
|
||||
else None,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _replace_template_variables(self, template: str,
|
||||
variables: list[dict],
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> str:
|
||||
def _replace_template_variables(
|
||||
self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Replace Template Variables
|
||||
:param template: template
|
||||
:param variables: list of variables
|
||||
:param external_data_variable_node_mapping: external data variable node mapping
|
||||
:return:
|
||||
"""
|
||||
for v in variables:
|
||||
template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
|
||||
template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}")
|
||||
|
||||
if external_data_variable_node_mapping:
|
||||
for variable, code_node_id in external_data_variable_node_mapping.items():
|
||||
template = template.replace('{{' + variable + '}}',
|
||||
'{{#' + code_node_id + '.result#}}')
|
||||
template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}")
|
||||
|
||||
return template
|
||||
|
||||
@ -613,11 +560,8 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
"outputs": [{
|
||||
"variable": "result",
|
||||
"value_selector": ["llm", "text"]
|
||||
}]
|
||||
}
|
||||
"outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_answer_node(self) -> dict:
|
||||
@ -629,11 +573,7 @@ class WorkflowConverter:
|
||||
return {
|
||||
"id": "answer",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "ANSWER",
|
||||
"type": NodeType.ANSWER.value,
|
||||
"answer": "{{#llm.text#}}"
|
||||
}
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
@ -643,11 +583,7 @@ class WorkflowConverter:
|
||||
:param target: target node id
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"id": f"{source}-{target}",
|
||||
"source": source,
|
||||
"target": target
|
||||
}
|
||||
return {"id": f"{source}-{target}", "source": source, "target": target}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
"""
|
||||
@ -657,9 +593,9 @@ class WorkflowConverter:
|
||||
:param node: Node to append
|
||||
:return:
|
||||
"""
|
||||
previous_node = graph['nodes'][-1]
|
||||
graph['nodes'].append(node)
|
||||
graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
|
||||
previous_node = graph["nodes"][-1]
|
||||
graph["nodes"].append(node)
|
||||
graph["edges"].append(self._create_edge(previous_node["id"], node["id"]))
|
||||
return graph
|
||||
|
||||
def _get_new_app_mode(self, app_model: App) -> AppMode:
|
||||
@ -673,14 +609,20 @@ class WorkflowConverter:
|
||||
else:
|
||||
return AppMode.ADVANCED_CHAT
|
||||
|
||||
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str):
|
||||
"""
|
||||
Get API Based Extension
|
||||
:param tenant_id: tenant id
|
||||
:param api_based_extension_id: api based extension id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(APIBasedExtension).filter(
|
||||
APIBasedExtension.tenant_id == tenant_id,
|
||||
APIBasedExtension.id == api_based_extension_id
|
||||
).first()
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}")
|
||||
|
||||
return api_based_extension
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
@ -8,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus
|
||||
|
||||
|
||||
class WorkflowAppService:
|
||||
|
||||
def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination:
|
||||
"""
|
||||
Get paginate workflow app logs
|
||||
@ -16,47 +17,51 @@ class WorkflowAppService:
|
||||
:param args: request args
|
||||
:return:
|
||||
"""
|
||||
query = (
|
||||
db.select(WorkflowAppLog)
|
||||
.where(
|
||||
WorkflowAppLog.tenant_id == app_model.tenant_id,
|
||||
WorkflowAppLog.app_id == app_model.id
|
||||
)
|
||||
query = db.select(WorkflowAppLog).where(
|
||||
WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
|
||||
)
|
||||
|
||||
status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None
|
||||
if args['keyword'] or status:
|
||||
query = query.join(
|
||||
WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id
|
||||
)
|
||||
status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None
|
||||
keyword = args["keyword"]
|
||||
if keyword or status:
|
||||
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
|
||||
|
||||
if args['keyword']:
|
||||
keyword_val = f"%{args['keyword'][:30]}%"
|
||||
if keyword:
|
||||
keyword_like_val = f"%{args['keyword'][:30]}%"
|
||||
keyword_conditions = [
|
||||
WorkflowRun.inputs.ilike(keyword_val),
|
||||
WorkflowRun.outputs.ilike(keyword_val),
|
||||
WorkflowRun.inputs.ilike(keyword_like_val),
|
||||
WorkflowRun.outputs.ilike(keyword_like_val),
|
||||
# filter keyword by end user session id if created by end user role
|
||||
and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_val))
|
||||
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
|
||||
]
|
||||
|
||||
# filter keyword by workflow run id
|
||||
keyword_uuid = self._safe_parse_uuid(keyword)
|
||||
if keyword_uuid:
|
||||
keyword_conditions.append(WorkflowRun.id == keyword_uuid)
|
||||
|
||||
query = query.outerjoin(
|
||||
EndUser,
|
||||
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value)
|
||||
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value),
|
||||
).filter(or_(*keyword_conditions))
|
||||
|
||||
if status:
|
||||
# join with workflow_run and filter by status
|
||||
query = query.filter(
|
||||
WorkflowRun.status == status.value
|
||||
)
|
||||
query = query.filter(WorkflowRun.status == status.value)
|
||||
|
||||
query = query.order_by(WorkflowAppLog.created_at.desc())
|
||||
|
||||
pagination = db.paginate(
|
||||
query,
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
)
|
||||
pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
|
||||
return pagination
|
||||
|
||||
@staticmethod
|
||||
def _safe_parse_uuid(value: str):
|
||||
# fast check
|
||||
if len(value) < 32:
|
||||
return None
|
||||
|
||||
try:
|
||||
return uuid.UUID(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@ -18,6 +18,7 @@ class WorkflowRunService:
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
"""
|
||||
|
||||
class WorkflowWithMessage:
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
@ -33,9 +34,7 @@ class WorkflowRunService:
|
||||
with_message_workflow_runs = []
|
||||
for workflow_run in pagination.data:
|
||||
message = workflow_run.message
|
||||
with_message_workflow_run = WorkflowWithMessage(
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run)
|
||||
if message:
|
||||
with_message_workflow_run.message_id = message.id
|
||||
with_message_workflow_run.conversation_id = message.conversation_id
|
||||
@ -53,26 +52,30 @@ class WorkflowRunService:
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
"""
|
||||
limit = int(args.get('limit', 20))
|
||||
limit = int(args.get("limit", 20))
|
||||
|
||||
base_query = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
)
|
||||
|
||||
if args.get('last_id'):
|
||||
if args.get("last_id"):
|
||||
last_workflow_run = base_query.filter(
|
||||
WorkflowRun.id == args.get('last_id'),
|
||||
WorkflowRun.id == args.get("last_id"),
|
||||
).first()
|
||||
|
||||
if not last_workflow_run:
|
||||
raise ValueError('Last workflow run not exists')
|
||||
raise ValueError("Last workflow run not exists")
|
||||
|
||||
workflow_runs = base_query.filter(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at,
|
||||
WorkflowRun.id != last_workflow_run.id
|
||||
).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
|
||||
workflow_runs = (
|
||||
base_query.filter(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
|
||||
)
|
||||
.order_by(WorkflowRun.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
|
||||
|
||||
@ -81,17 +84,13 @@ class WorkflowRunService:
|
||||
current_page_first_workflow_run = workflow_runs[-1]
|
||||
rest_count = base_query.filter(
|
||||
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
|
||||
WorkflowRun.id != current_page_first_workflow_run.id
|
||||
WorkflowRun.id != current_page_first_workflow_run.id,
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=workflow_runs,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
|
||||
|
||||
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
@ -100,11 +99,15 @@ class WorkflowRunService:
|
||||
:param app_model: app model
|
||||
:param run_id: workflow run id
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.id == run_id,
|
||||
).first()
|
||||
workflow_run = (
|
||||
db.session.query(WorkflowRun)
|
||||
.filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.id == run_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
@ -117,12 +120,17 @@ class WorkflowRunService:
|
||||
if not workflow_run:
|
||||
return []
|
||||
|
||||
node_executions = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
|
||||
WorkflowNodeExecution.app_id == app_model.id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == run_id,
|
||||
).order_by(WorkflowNodeExecution.index.desc()).all()
|
||||
node_executions = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
|
||||
WorkflowNodeExecution.app_id == app_model.id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == run_id,
|
||||
)
|
||||
.order_by(WorkflowNodeExecution.index.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return node_executions
|
||||
|
||||
@ -37,11 +37,13 @@ class WorkflowService:
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == 'draft'
|
||||
).first()
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
@ -55,11 +57,15 @@ class WorkflowService:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id
|
||||
).first()
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
@ -72,6 +78,7 @@ class WorkflowService:
|
||||
unique_hash: Optional[str],
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
@ -84,10 +91,7 @@ class WorkflowService:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(
|
||||
app_model=app_model,
|
||||
features=features
|
||||
)
|
||||
self.validate_features_structure(app_model=app_model, features=features)
|
||||
|
||||
# create draft workflow if not found
|
||||
if not workflow:
|
||||
@ -95,11 +99,12 @@ class WorkflowService:
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(app_model.mode).value,
|
||||
version='draft',
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id,
|
||||
environment_variables=environment_variables
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
@ -109,6 +114,7 @@ class WorkflowService:
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
@ -119,9 +125,7 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def publish_workflow(self, app_model: App,
|
||||
account: Account,
|
||||
draft_workflow: Optional[Workflow] = None) -> Workflow:
|
||||
def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow:
|
||||
"""
|
||||
Publish workflow from draft
|
||||
|
||||
@ -134,7 +138,7 @@ class WorkflowService:
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not draft_workflow:
|
||||
raise ValueError('No valid workflow found.')
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow(
|
||||
@ -145,7 +149,8 @@ class WorkflowService:
|
||||
graph=draft_workflow.graph,
|
||||
features=draft_workflow.features,
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
@ -183,17 +188,16 @@ class WorkflowService:
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
return workflow_engine_manager.get_default_config(node_type, filters)
|
||||
|
||||
def run_draft_workflow_node(self, app_model: App,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
account: Account) -> WorkflowNodeExecution:
|
||||
def run_draft_workflow_node(
|
||||
self, app_model: App, node_id: str, user_inputs: dict, account: Account
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# run draft workflow node
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
@ -222,7 +226,7 @@ class WorkflowService:
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
@ -243,14 +247,15 @@ class WorkflowService:
|
||||
inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
|
||||
process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
|
||||
outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
|
||||
execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
|
||||
if node_run_result.metadata else None),
|
||||
execution_metadata=(
|
||||
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
||||
),
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
else:
|
||||
# create workflow node execution
|
||||
@ -269,7 +274,7 @@ class WorkflowService:
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
@ -291,15 +296,16 @@ class WorkflowService:
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]:
|
||||
raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.')
|
||||
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
|
||||
# convert to workflow
|
||||
new_app = workflow_converter.convert_to_workflow(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
name=args.get('name'),
|
||||
icon=args.get('icon'),
|
||||
icon_background=args.get('icon_background'),
|
||||
name=args.get("name"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
)
|
||||
|
||||
return new_app
|
||||
@ -307,15 +313,33 @@ class WorkflowService:
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=features,
|
||||
only_structure_validate=True
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
return WorkflowAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=features,
|
||||
only_structure_validate=True
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode: {app_model.mode}")
|
||||
|
||||
@classmethod
|
||||
def get_elapsed_time(cls, workflow_run_id: str) -> float:
|
||||
"""
|
||||
Get elapsed time
|
||||
"""
|
||||
elapsed_time = 0.0
|
||||
|
||||
# fetch workflow node execution by workflow_run_id
|
||||
workflow_nodes = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id)
|
||||
.order_by(WorkflowNodeExecution.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
if not workflow_nodes:
|
||||
return elapsed_time
|
||||
|
||||
for node in workflow_nodes:
|
||||
elapsed_time += node.elapsed_time
|
||||
|
||||
return elapsed_time
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
@ -14,34 +13,40 @@ class WorkspaceService:
|
||||
if not tenant:
|
||||
return None
|
||||
tenant_info = {
|
||||
'id': tenant.id,
|
||||
'name': tenant.name,
|
||||
'plan': tenant.plan,
|
||||
'status': tenant.status,
|
||||
'created_at': tenant.created_at,
|
||||
'in_trail': True,
|
||||
'trial_end_reason': None,
|
||||
'role': 'normal',
|
||||
"id": tenant.id,
|
||||
"name": tenant.name,
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"in_trail": True,
|
||||
"trial_end_reason": None,
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
# Get role of user
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.account_id == current_user.id
|
||||
).first()
|
||||
tenant_info['role'] = tenant_account_join.role
|
||||
tenant_account_join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
tenant_info["role"] = tenant_account_join.role
|
||||
|
||||
can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
|
||||
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
|
||||
|
||||
if can_replace_logo and TenantService.has_roles(tenant,
|
||||
[TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
|
||||
if can_replace_logo and TenantService.has_roles(
|
||||
tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]
|
||||
):
|
||||
base_url = dify_config.FILES_URL
|
||||
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
|
||||
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
|
||||
replace_webapp_logo = (
|
||||
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"
|
||||
if tenant.custom_config_dict.get("replace_webapp_logo")
|
||||
else None
|
||||
)
|
||||
remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False)
|
||||
|
||||
tenant_info['custom_config'] = {
|
||||
'remove_webapp_brand': remove_webapp_brand,
|
||||
'replace_webapp_logo': replace_webapp_logo,
|
||||
tenant_info["custom_config"] = {
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
|
||||
return tenant_info
|
||||
|
||||
Reference in New Issue
Block a user