diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 463d468698..9bd2637007 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -282,76 +282,132 @@ class DeleteExploreBannerApi(Resource): return {"result": "success"}, 204 -class SaveNotificationContentPayload(BaseModel): - content: str = Field(...) +class LangContentPayload(BaseModel): + lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'") + title: str = Field(...) + body: str = Field(...) + cta_label: str = Field(...) + cta_url: str = Field(...) -class SaveNotificationUserPayload(BaseModel): - user_email: list[str] = Field(...) +class UpsertNotificationPayload(BaseModel): + notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update") + contents: list[LangContentPayload] = Field(..., min_length=1) + start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z") + end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z") + frequency: str = Field(default="once", description="'once' | 'every_page_load'") + status: str = Field(default="active", description="'active' | 'inactive'") + + +class BatchAddNotificationAccountsPayload(BaseModel): + notification_id: str = Field(...) + user_email: list[str] = Field(..., description="List of account email addresses") console_ns.schema_model( - SaveNotificationContentPayload.__name__, - SaveNotificationContentPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), + UpsertNotificationPayload.__name__, + UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) console_ns.schema_model( - SaveNotificationUserPayload.__name__, - SaveNotificationUserPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), + BatchAddNotificationAccountsPayload.__name__, + BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) -@console_ns.route("/admin/save_notification_content") -class SaveNotificationContentApi(Resource): - @console_ns.doc("save_notification_content") - @console_ns.doc(description="Save a notification content") - @console_ns.expect(console_ns.models[SaveNotificationContentPayload.__name__]) - @console_ns.response(200, "Notification content saved successfully") - @only_edition_cloud - @admin_required - def post(self): - payload = SaveNotificationContentPayload.model_validate(console_ns.payload) - BillingService.save_notification_content(payload.content) - return {"result": "success"}, 200 - - -@console_ns.route("/admin/save_notification_user") -class SaveNotificationUserApi(Resource): - @console_ns.doc("save_notification_user") +@console_ns.route("/admin/upsert_notification") +class UpsertNotificationApi(Resource): + @console_ns.doc("upsert_notification") @console_ns.doc( - description="Save notification users via JSON body or file upload. " - 'JSON: {"user_email": ["a@example.com", ...]}. ' - "File: multipart/form-data with a 'file' field (CSV or TXT, one email per line)." + description=( + "Create or update an in-product notification. " + "Supply notification_id to update an existing one; omit it to create a new one. " + "Pass at least one language variant in contents (zh / en / jp)." + ) ) - @console_ns.response(200, "Notification users saved successfully") + @console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__]) + @console_ns.response(200, "Notification upserted successfully") @only_edition_cloud @admin_required def post(self): - # Determine input mode: file upload or JSON body + payload = UpsertNotificationPayload.model_validate(console_ns.payload) + result = BillingService.upsert_notification( + contents=[c.model_dump() for c in payload.contents], + frequency=payload.frequency, + status=payload.status, + notification_id=payload.notification_id, + start_time=payload.start_time, + end_time=payload.end_time, + ) + return {"result": "success", "notification_id": result.get("notificationId")}, 200 + + +@console_ns.route("/admin/batch_add_notification_accounts") +class BatchAddNotificationAccountsApi(Resource): + @console_ns.doc("batch_add_notification_accounts") + @console_ns.doc( + description=( + "Register target accounts for a notification by email address. " + 'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. ' + "File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) " + "plus a 'notification_id' field. " + "Emails that do not match any account are silently skipped." + ) + ) + @console_ns.response(200, "Accounts added successfully") + @only_edition_cloud + @admin_required + def post(self): + from models.account import Account + if "file" in request.files: + notification_id = request.form.get("notification_id", "").strip() + if not notification_id: + raise BadRequest("notification_id is required.") emails = self._parse_emails_from_file() else: - payload = SaveNotificationUserPayload.model_validate(console_ns.payload) + payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload) + notification_id = payload.notification_id emails = payload.user_email if not emails: raise BadRequest("No valid email addresses provided.") - # Use batch API for bulk insert (chunks of 1000 per request to billing service) - result = BillingService.save_notification_users_batch(emails) + # Resolve emails → account IDs in chunks to avoid large IN-clause + account_ids: list[str] = [] + chunk_size = 500 + for i in range(0, len(emails), chunk_size): + chunk = emails[i : i + chunk_size] + rows = db.session.execute( + select(Account.id, Account.email).where(Account.email.in_(chunk)) + ).all() + account_ids.extend(str(row.id) for row in rows) + + if not account_ids: + raise BadRequest("None of the provided emails matched an existing account.") + + # Send to dify-saas in batches of 1000 + total_count = 0 + batch_size = 1000 + for i in range(0, len(account_ids), batch_size): + batch = account_ids[i : i + batch_size] + result = BillingService.batch_add_notification_accounts( + notification_id=notification_id, + account_ids=batch, + ) + total_count += result.get("count", 0) return { "result": "success", - "total": len(emails), - "succeeded": result["succeeded"], - "failed_chunks": result["failed_chunks"], + "emails_provided": len(emails), + "accounts_matched": len(account_ids), + "count": total_count, }, 200 @staticmethod def _parse_emails_from_file() -> list[str]: """Parse email addresses from an uploaded CSV or TXT file.""" file = request.files["file"] - if not file.filename: raise BadRequest("Uploaded file has no filename.") @@ -359,7 +415,6 @@ class SaveNotificationUserApi(Resource): if not filename_lower.endswith((".csv", ".txt")): raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.") - # Read file content try: content = file.read().decode("utf-8") except UnicodeDecodeError: @@ -375,20 +430,20 @@ class SaveNotificationUserApi(Resource): for row in reader: for cell in row: cell = cell.strip() - emails.append(cell) + if cell: + emails.append(cell) else: - # TXT file: one email per line for line in content.splitlines(): line = line.strip() - emails.append(line) + if line: + emails.append(line) # Deduplicate while preserving order seen: set[str] = set() unique_emails: list[str] = [] for email in emails: - email_lower = email.lower() - if email_lower not in seen: - seen.add(email_lower) + if email.lower() not in seen: + seen.add(email.lower()) unique_emails.append(email) return unique_emails diff --git a/api/controllers/console/notification.py b/api/controllers/console/notification.py index dd3feba974..186d93ef2e 100644 --- a/api/controllers/console/notification.py +++ b/api/controllers/console/notification.py @@ -5,16 +5,47 @@ from controllers.console.wraps import account_initialization_required, only_edit from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService +# Notification content is stored under three lang tags. +_FALLBACK_LANG = "en" + +# Maps dify interface_language prefixes to notification lang tags. +# Any unrecognised prefix falls back to _FALLBACK_LANG. +_LANG_MAP: dict[str, str] = { + "zh": "zh", + "ja": "jp", +} + + +def _resolve_lang(interface_language: str | None) -> str: + """Derive the notification lang tag from the user's interface_language. + + e.g. "zh-Hans" → "zh", "ja-JP" → "jp", "en-US" / None → "en" + """ + if not interface_language: + return _FALLBACK_LANG + prefix = interface_language.split("-")[0].lower() + return _LANG_MAP.get(prefix, _FALLBACK_LANG) + + +def _pick_lang_content(contents: dict, lang: str) -> dict: + """Return the single LangContent for *lang*, falling back to English.""" + return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) + @console_ns.route("/notification") class NotificationApi(Resource): @console_ns.doc("get_notification") - @console_ns.doc(description="Get notification for the current user") @console_ns.doc( + description=( + "Return the active in-product notification for the current user " + "in their interface language (falls back to English if unavailable). " + "Calling this endpoint also marks the notification as seen; subsequent " + "calls return should_show=false when frequency='once'." + ), responses={ - 200: "Success", + 200: "Success — inspect should_show to decide whether to render the modal", 401: "Unauthorized", - } + }, ) @setup_required @login_required @@ -22,5 +53,28 @@ class NotificationApi(Resource): @only_edition_cloud def get(self): current_user, _ = current_account_with_tenant() - notification = BillingService.read_notification(current_user.email) - return notification + + result = BillingService.get_account_notification(str(current_user.id)) + + # Proto JSON uses camelCase field names (Kratos default marshaling). + if not result.get("shouldShow"): + return {"should_show": False}, 200 + + notification = result.get("notification") or {} + contents: dict = notification.get("contents") or {} + + lang = _resolve_lang(current_user.interface_language) + lang_content = _pick_lang_content(contents, lang) + + return { + "should_show": True, + "notification": { + "notification_id": notification.get("notificationId"), + "frequency": notification.get("frequency"), + "lang": lang_content.get("lang", lang), + "title": lang_content.get("title", ""), + "body": lang_content.get("body", ""), + "cta_label": lang_content.get("ctaLabel", ""), + "cta_url": lang_content.get("ctaUrl", ""), + }, + }, 200 diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 4b0d700d50..b23df5832d 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -394,34 +394,66 @@ class BillingService: tenant_whitelist.append(item["tenant_id"]) return tenant_whitelist - @classmethod - def read_notification(cls, user_email: str): - params = {"user_email": user_email} - return cls._send_request("GET", "/notification/read", params=params) @classmethod - def save_notification_user(cls, user_email: str): - json = {"user_email": user_email} - return cls._send_request("POST", "/notification/new-notification-user", json=json) + def get_account_notification(cls, account_id: str) -> dict: + """Return the active in-product notification for account_id, if any. + + Calling this endpoint also marks the notification as seen; subsequent + calls will return should_show=false when frequency='once'. + + Response shape (mirrors GetAccountNotificationReply): + { + "should_show": bool, + "notification": { # present only when should_show=true + "notification_id": str, + "contents": { # lang -> LangContent + "en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...}, + ... + }, + "frequency": "once" | "every_page_load" + } + } + """ + return cls._send_request("GET", "/notifications/active", params={"account_id": account_id}) @classmethod - def save_notification_users_batch(cls, user_emails: list[str]) -> dict: - """Batch save notification users in chunks of 1000.""" - chunk_size = 1000 - total_succeeded = 0 - failed_chunks: list[dict] = [] + def upsert_notification( + cls, + contents: list[dict], + frequency: str = "once", + status: str = "active", + notification_id: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + ) -> dict: + """Create or update a notification. - for i in range(0, len(user_emails), chunk_size): - chunk = user_emails[i : i + chunk_size] - try: - resp = cls._send_request("POST", "/notification/batch-notification-users", json={"user_emails": chunk}) - total_succeeded += resp.get("count", len(chunk)) - except Exception as e: - failed_chunks.append({"offset": i, "count": len(chunk), "error": str(e)}) - - return {"succeeded": total_succeeded, "failed_chunks": failed_chunks} + contents: list of {"lang": str, "title": str, "body": str, "cta_label": str, "cta_url": str} + start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional. + Returns {"notification_id": str}. + """ + payload: dict = { + "contents": contents, + "frequency": frequency, + "status": status, + } + if notification_id: + payload["notification_id"] = notification_id + if start_time: + payload["start_time"] = start_time + if end_time: + payload["end_time"] = end_time + return cls._send_request("POST", "/notifications", json=payload) @classmethod - def save_notification_content(cls, content: str): - json = {"content": content} - return cls._send_request("POST", "/notification/new-notification", json=json) + def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict: + """Register target account IDs for a notification (max 1000 per call). + + Returns {"count": int}. + """ + return cls._send_request( + "POST", + f"/notifications/{notification_id}/accounts", + json={"account_ids": account_ids}, + )