Compare commits

..

2 Commits

6 changed files with 199 additions and 41 deletions

View File

@ -3,6 +3,7 @@ from urllib import parse
from flask import abort, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
import services
from configs import dify_config
@ -21,15 +22,15 @@ from controllers.console.auth.error import (
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
is_allow_transfer_owner,
setup_required,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from models.account import Account, TenantAccountJoin, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@ -78,6 +79,54 @@ def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
return list(dict.fromkeys(email.lower() for email in emails))
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
new_member_count = 0
for email in emails:
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
new_member_count += 1
continue
exists = db.session.scalar(
select(TenantAccountJoin.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not exists:
new_member_count += 1
return new_member_count
def _count_current_members(tenant_id: str) -> int:
return (
db.session.scalar(select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.tenant_id == tenant_id)) or 0
)
def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
if new_member_count <= 0:
return
features = FeatureService.get_features(tenant_id=tenant_id)
if dify_config.ENTERPRISE_ENABLED:
workspace_members = features.workspace_members
if workspace_members.enabled is True and not workspace_members.is_available(new_member_count):
raise WorkspaceMembersLimitExceeded()
return
if dify_config.BILLING_ENABLED and features.billing.enabled is True:
members = features.members
current_member_count = _count_current_members(tenant_id)
if 0 < members.limit < current_member_count + new_member_count:
raise WorkspaceMembersLimitExceeded()
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
"""List all members of current tenant."""
@ -104,12 +153,11 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
invitee_emails = args.emails
invitee_emails = _normalize_invitee_emails(args.emails)
invitee_role = args.role
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
@ -129,37 +177,36 @@ class MemberInviteEmailApi(Resource):
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
tenant_id = inviter.current_tenant.id
with redis_client.lock(f"workspace_member_invite:{tenant_id}", timeout=60):
new_member_count = _count_new_member_invites(tenant_id, invitee_emails)
_check_member_invite_limits(tenant_id, new_member_count)
if not workspace_members.is_available(len(invitee_emails)):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
for invitee_email in invitee_emails:
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
return {
"result": "success",

View File

@ -1,3 +1,4 @@
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import patch
@ -18,7 +19,7 @@ def app():
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
workspace_members = SimpleNamespace(enabled=False, is_available=lambda count: True)
return SimpleNamespace(
billing=SimpleNamespace(enabled=False),
workspace_members=workspace_members,
@ -31,6 +32,11 @@ def _build_feature_flags():
class TestMemberInviteEmailApi:
@pytest.fixture(autouse=True)
def _mock_member_invite_lock(self):
with patch("controllers.console.workspace.members.redis_client.lock", return_value=nullcontext()):
yield
@patch("controllers.console.workspace.members.FeatureService.get_features")
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
@patch("controllers.console.workspace.members.current_account_with_tenant")
@ -52,7 +58,12 @@ class TestMemberInviteEmailApi:
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
mock_current_account.return_value = (inviter, tenant.id)
with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"):
with (
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
with app.test_request_context(
"/workspaces/current/members/invite-email",
method="POST",
@ -70,7 +81,7 @@ class TestMemberInviteEmailApi:
assert mock_invite_member.call_count == 1
call_args = mock_invite_member.call_args
assert call_args.kwargs["tenant"] == tenant
assert call_args.kwargs["email"] == "User@Example.com"
assert call_args.kwargs["email"] == "user@example.com"
assert call_args.kwargs["language"] == "en-US"
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
assert call_args.kwargs["inviter"] == inviter

View File

@ -1,3 +1,4 @@
from contextlib import nullcontext
from unittest.mock import MagicMock, patch
import pytest
@ -75,6 +76,11 @@ class TestMemberListApi:
class TestMemberInviteEmailApi:
@pytest.fixture(autouse=True)
def _mock_member_invite_lock(self):
with patch("controllers.console.workspace.members.redis_client.lock", return_value=nullcontext()):
yield
def test_invite_success(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -82,6 +88,8 @@ class TestMemberInviteEmailApi:
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.billing.enabled = False
features.workspace_members.enabled = False
features.workspace_members.is_available.return_value = True
payload = {
@ -94,8 +102,11 @@ class TestMemberInviteEmailApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
result, status = method(api)
@ -109,6 +120,8 @@ class TestMemberInviteEmailApi:
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.billing.enabled = False
features.workspace_members.enabled = True
features.workspace_members.is_available.return_value = False
payload = {
@ -120,6 +133,38 @@ class TestMemberInviteEmailApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_billing_limit_exceeded(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.billing.enabled = True
features.members.size = 9
features.members.limit = 10
features.workspace_members.enabled = False
payload = {
"emails": ["a@test.com", "b@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=2),
patch("controllers.console.workspace.members._count_current_members", return_value=9),
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", True),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
@ -131,6 +176,8 @@ class TestMemberInviteEmailApi:
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.billing.enabled = False
features.workspace_members.enabled = False
features.workspace_members.is_available.return_value = True
payload = {
@ -142,11 +189,14 @@ class TestMemberInviteEmailApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=0),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=AccountAlreadyInTenantError(),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
result, status = method(api)
@ -174,6 +224,8 @@ class TestMemberInviteEmailApi:
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.billing.enabled = False
features.workspace_members.enabled = False
features.workspace_members.is_available.return_value = True
payload = {
@ -185,11 +237,14 @@ class TestMemberInviteEmailApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=Exception("boom"),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
result, _ = method(api)

View File

@ -95,6 +95,7 @@ export type CurrentPlanInfoBackend = {
}
webapp_copyright_enabled: boolean
workspace_members: {
enabled?: boolean
size: number
limit: number
}

View File

@ -28,6 +28,37 @@ type ProviderContextProviderProps = {
children: ReactNode
}
type MemberInviteLimit = {
size: number
limit: number
}
const unlimitedMemberInviteLimit: MemberInviteLimit = {
size: 0,
limit: 0,
}
const resolveMemberInviteLimit = (data: Awaited<ReturnType<typeof fetchCurrentPlanInfo>>): MemberInviteLimit => {
if (!data)
return unlimitedMemberInviteLimit
if (data.workspace_members?.enabled) {
return {
size: data.workspace_members.size,
limit: data.workspace_members.limit,
}
}
if (data.billing?.enabled && data.members?.limit > 0) {
return {
size: data.members.size,
limit: data.members.limit,
}
}
return unlimitedMemberInviteLimit
}
export const ProviderContextProvider = ({
children,
}: ProviderContextProviderProps) => {
@ -87,8 +118,7 @@ export const ProviderContextProvider = ({
setDatasetOperatorEnabled(true)
if (data.webapp_copyright_enabled)
setWebappCopyrightEnabled(true)
if (data.workspace_members)
setLicenseLimit({ workspace_members: data.workspace_members })
setLicenseLimit({ workspace_members: resolveMemberInviteLimit(data) })
if (data.is_allow_transfer_workspace)
setIsAllowTransferWorkspace(data.is_allow_transfer_workspace)
if (data.knowledge_pipeline?.publish_enabled)

View File

@ -1,8 +1,22 @@
import type { CurrentPlanInfoBackend, SubscriptionUrlsBackend } from '@/app/components/billing/type'
import { get } from './base'
export const fetchCurrentPlanInfo = () => {
return get<CurrentPlanInfoBackend>('/features')
type CurrentPlanInfoResponse = Omit<CurrentPlanInfoBackend, 'workspace_members'> & {
workspace_members: Omit<CurrentPlanInfoBackend['workspace_members'], 'size'> & {
size: number | null
}
}
export const fetchCurrentPlanInfo = async (): Promise<CurrentPlanInfoBackend> => {
const data = await get<CurrentPlanInfoResponse>('/features')
return {
...data,
workspace_members: {
...data.workspace_members,
size: data.workspace_members.size ?? 0,
},
}
}
export const fetchSubscriptionUrls = (plan: string, interval: string) => {