Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-12-31 16:47:56 +08:00
93 changed files with 1717 additions and 911 deletions

View File

@ -33,6 +33,7 @@ from models.account import (
TenantStatus,
)
from models.model import DifySetup
from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
@ -51,6 +52,8 @@ from services.errors.account import (
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task
@ -71,6 +74,9 @@ class AccountService:
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
)
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
)
LOGIN_MAX_ERROR_LIMITS = 5
@staticmethod
@ -202,6 +208,15 @@ class AccountService:
from controllers.console.error import AccountNotFound
raise AccountNotFound()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
"30 days and is temporarily unavailable for new account registration"
)
)
account = Account()
account.email = email
account.name = name
@ -241,6 +256,42 @@ class AccountService:
return account
@staticmethod
def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]:
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, token_type="account_deletion", additional_data={"code": code}
)
return token, code
@classmethod
def send_account_deletion_verification_email(cls, account: Account, code: str):
email = account.email
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
raise EmailCodeAccountDeletionRateLimitExceededError()
send_account_deletion_verification_code.delay(to=email, code=code)
cls.email_code_account_deletion_rate_limiter.increment_rate_limit(email)
@staticmethod
def verify_account_deletion_code(token: str, code: str) -> bool:
token_data = TokenManager.get_token_data(token, "account_deletion")
if token_data is None:
return False
if token_data["code"] != code:
return False
return True
@staticmethod
def delete_account(account: Account) -> None:
"""Delete account. This method only adds a task to the queue for deletion."""
delete_account_task.delay(account.id)
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
"""Link account integrate"""
@ -380,6 +431,7 @@ class AccountService:
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
email = account.email if account else email
if email is None:
raise ValueError("Email must be provided.")
if cls.email_code_login_rate_limiter.is_rate_limited(email):
@ -409,6 +461,14 @@ class AccountService:
@classmethod
def get_user_through_email(cls, email: str):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
"30 days and is temporarily unavailable for new account registration"
)
)
account = db.session.query(Account).filter(Account.email == email).first()
if not account:
return None
@ -825,6 +885,10 @@ class RegisterService:
db.session.commit()
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
except AccountRegisterError as are:
db.session.rollback()
logging.exception("Register failed")
raise are
except Exception as e:
db.session.rollback()
logging.exception("Register failed")

View File

@ -139,7 +139,7 @@ class AudioService:
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
else:
if not text:
if text is None:
raise ValueError("Text is required")
response = invoke_tts(text, app_model, voice)
if isinstance(response, Generator):

View File

@ -70,3 +70,24 @@ class BillingService:
if not TenantAccountRole.is_privileged_role(join.role):
raise ValueError("Only team owner or team admin can perform this action")
@classmethod
def delete_account(cls, account_id: str):
"""Delete account."""
params = {"account_id": account_id}
return cls._send_request("DELETE", "/account/", params=params)
@classmethod
def is_email_in_freeze(cls, email: str) -> bool:
params = {"email": email}
try:
response = cls._send_request("GET", "/account/in-freeze", params=params)
return bool(response.get("data", False))
except Exception:
return False
@classmethod
def update_account_deletion_feedback(cls, email: str, feedback: str):
"""Update account deletion feedback."""
json = {"email": email, "feedback": feedback}
return cls._send_request("POST", "/account/delete-feedback", json=json)

View File

@ -86,25 +86,30 @@ class DatasetService:
else:
return [], 0
else:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id),
db.and_(
Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
Dataset.id.in_(permitted_dataset_ids),
),
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
),
db.and_(
Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
Dataset.id.in_(permitted_dataset_ids),
),
)
)
)
else:
query = query.filter(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id),
else:
query = query.filter(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
),
)
)
)
else:
# if no user, only show datasets that are shared with all team members
query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
@ -377,14 +382,19 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == "partial_members":
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == "partial_members":
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
if (
not user_permission
and dataset.tenant_id != user.current_tenant_id
and dataset.created_by != user.id
):
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
@ -394,15 +404,16 @@ class DatasetService:
if not user:
raise ValueError("User not found")
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
@ -441,7 +452,7 @@ class DatasetService:
class DocumentService:
DEFAULT_RULES = {
DEFAULT_RULES: dict[str, Any] = {
"mode": "custom",
"rules": {
"pre_processing_rules": [
@ -455,7 +466,7 @@ class DocumentService:
},
}
DOCUMENT_METADATA_SCHEMA = {
DOCUMENT_METADATA_SCHEMA: dict[str, Any] = {
"book": {
"title": str,
"language": str,

View File

@ -439,7 +439,7 @@ class ApiToolManageService:
tenant_id=tenant_id,
)
)
result = tool.validate_credentials(credentials, parameters)
result = runtime_tool.validate_credentials(credentials, parameters)
except Exception as e:
return {"error": str(e)}

View File

@ -5,6 +5,8 @@ from datetime import UTC, datetime
from typing import Any, Optional
from uuid import uuid4
from sqlalchemy import desc
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
@ -77,6 +79,28 @@ class WorkflowService:
return workflow
def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]:
"""
Get published workflow with pagination
"""
if not app_model.workflow_id:
return [], False
workflows = (
db.session.query(Workflow)
.filter(Workflow.app_id == app_model.id)
.order_by(desc(Workflow.version))
.offset((page - 1) * limit)
.limit(limit + 1)
.all()
)
has_more = len(workflows) > limit
if has_more:
workflows = workflows[:-1]
return workflows, has_more
def sync_draft_workflow(
self,
*,