mirror of
https://github.com/langgenius/dify.git
synced 2026-04-07 06:07:18 +08:00
Compare commits
44 Commits
deploy/dev
...
3-18-no-gl
| Author | SHA1 | Date | |
|---|---|---|---|
| 01fe3106f5 | |||
| eabaffbf56 | |||
| d032e236ad | |||
| 09ca243f0d | |||
| 799e754b10 | |||
| a8fe91438c | |||
| db343a6a94 | |||
| 934563d068 | |||
| 91381c7e85 | |||
| 0f59c06093 | |||
| 58b08b3b24 | |||
| 4d1958a9d8 | |||
| 96bdb234ef | |||
| 066be7fe45 | |||
| eb56a2c6fc | |||
| 093664757f | |||
| 479b813531 | |||
| 8673d4433b | |||
| 90443645b2 | |||
| 52a65fcd5a | |||
| f7ce52e212 | |||
| ddb833e9a1 | |||
| 92887e84f0 | |||
| 56c52521ee | |||
| 088a481432 | |||
| d9c23757d1 | |||
| a768fc37cf | |||
| b1ff105e8c | |||
| 4dc965aebf | |||
| 71a9125924 | |||
| 7fc0014d66 | |||
| 8e52267acb | |||
| b80e944665 | |||
| 60bae05d0f | |||
| a19cd375d4 | |||
| 8eb9f88c3b | |||
| 8246985ad8 | |||
| 63cae12bef | |||
| 3dfa3fd5cd | |||
| af9c577bf6 | |||
| 63eed30e78 | |||
| 54d29cdf70 | |||
| 81022eb834 | |||
| 1620e224bf |
@ -1,6 +1,4 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
@ -11,7 +9,6 @@ from werkzeug.exceptions import BadRequest
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@ -87,39 +84,3 @@ class PartnerTenants(Resource):
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
||||
|
||||
_DEBUG_KEY = "billing:debug"
|
||||
_DEBUG_TTL = timedelta(days=7)
|
||||
|
||||
|
||||
class DebugDataPayload(BaseModel):
|
||||
type: str = Field(..., min_length=1, description="Data type key")
|
||||
data: str = Field(..., min_length=1, description="Data value to append")
|
||||
|
||||
|
||||
@console_ns.route("/billing/debug/data")
|
||||
class DebugData(Resource):
|
||||
def post(self):
|
||||
body = DebugDataPayload.model_validate(request.get_json(force=True))
|
||||
item = json.dumps({
|
||||
"type": body.type,
|
||||
"data": body.data,
|
||||
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
})
|
||||
redis_client.lpush(_DEBUG_KEY, item)
|
||||
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
|
||||
return {"result": "ok"}, 201
|
||||
|
||||
def get(self):
|
||||
recent = request.args.get("recent", 10, type=int)
|
||||
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
|
||||
return {
|
||||
"data": [
|
||||
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
|
||||
]
|
||||
}
|
||||
|
||||
def delete(self):
|
||||
redis_client.delete(_DEBUG_KEY)
|
||||
return {"result": "ok"}
|
||||
|
||||
@ -1,17 +1,56 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum, auto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCharge:
|
||||
"""
|
||||
Result of a quota consumption operation.
|
||||
|
||||
Attributes:
|
||||
success: Whether the quota charge succeeded
|
||||
charge_id: UUID for refund, or None if failed/disabled
|
||||
"""
|
||||
|
||||
success: bool
|
||||
charge_id: str | None
|
||||
_quota_type: "QuotaType"
|
||||
|
||||
def refund(self) -> None:
|
||||
"""
|
||||
Refund this quota charge.
|
||||
|
||||
Safe to call even if charge failed or was disabled.
|
||||
This method guarantees no exceptions will be raised.
|
||||
"""
|
||||
if self.charge_id:
|
||||
self._quota_type.refund(self.charge_id)
|
||||
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
|
||||
|
||||
|
||||
class QuotaType(StrEnum):
|
||||
"""
|
||||
Supported quota types for tenant feature usage.
|
||||
|
||||
Add additional types here whenever new billable features become available.
|
||||
"""
|
||||
|
||||
# Trigger execution quota
|
||||
TRIGGER = auto()
|
||||
|
||||
# Workflow execution quota
|
||||
WORKFLOW = auto()
|
||||
|
||||
UNLIMITED = auto()
|
||||
|
||||
@property
|
||||
def billing_key(self) -> str:
|
||||
"""
|
||||
Get the billing key for the feature.
|
||||
"""
|
||||
match self:
|
||||
case QuotaType.TRIGGER:
|
||||
return "trigger_event"
|
||||
@ -19,3 +58,152 @@ class QuotaType(StrEnum):
|
||||
return "api_rate_limit"
|
||||
case _:
|
||||
raise ValueError(f"Invalid quota type: {self}")
|
||||
|
||||
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Consume quota for the feature.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to consume (default: 1)
|
||||
|
||||
Returns:
|
||||
QuotaCharge with success status and charge_id for refund
|
||||
|
||||
Raises:
|
||||
QuotaExceededError: When quota is insufficient
|
||||
"""
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.debug("Billing disabled, allowing request for %s", tenant_id)
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
|
||||
|
||||
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to consume must be greater than 0")
|
||||
|
||||
try:
|
||||
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
|
||||
|
||||
if response.get("result") != "success":
|
||||
logger.warning(
|
||||
"Failed to consume quota for %s, feature %s details: %s",
|
||||
tenant_id,
|
||||
self.value,
|
||||
response.get("detail"),
|
||||
)
|
||||
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
|
||||
|
||||
charge_id = response.get("history_id")
|
||||
logger.debug(
|
||||
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
|
||||
amount,
|
||||
self.value,
|
||||
tenant_id,
|
||||
charge_id,
|
||||
)
|
||||
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
# fail-safe: allow request on billing errors
|
||||
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
|
||||
return unlimited()
|
||||
|
||||
def check(self, tenant_id: str, amount: int = 1) -> bool:
|
||||
"""
|
||||
Check if tenant has sufficient quota without consuming.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to check (default: 1)
|
||||
|
||||
Returns:
|
||||
True if quota is sufficient, False otherwise
|
||||
"""
|
||||
from configs import dify_config
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return True
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to check must be greater than 0")
|
||||
|
||||
try:
|
||||
remaining = self.get_remaining(tenant_id)
|
||||
return remaining >= amount if remaining != -1 else True
|
||||
except Exception:
|
||||
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
|
||||
# fail-safe: allow request on billing errors
|
||||
return True
|
||||
|
||||
def refund(self, charge_id: str) -> None:
|
||||
"""
|
||||
Refund quota using charge_id from consume().
|
||||
|
||||
This method guarantees no exceptions will be raised.
|
||||
All errors are logged but silently handled.
|
||||
|
||||
Args:
|
||||
charge_id: The UUID returned from consume()
|
||||
"""
|
||||
try:
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return
|
||||
|
||||
if not charge_id:
|
||||
logger.warning("Cannot refund: charge_id is empty")
|
||||
return
|
||||
|
||||
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
|
||||
|
||||
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
|
||||
if response.get("result") == "success":
|
||||
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
|
||||
else:
|
||||
logger.warning("Refund failed for charge_id: %s", charge_id)
|
||||
|
||||
except Exception:
|
||||
# Catch ALL exceptions - refund must never fail
|
||||
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
|
||||
# Don't raise - refund is best-effort and must be silent
|
||||
|
||||
def get_remaining(self, tenant_id: str) -> int:
|
||||
"""
|
||||
Get remaining quota for the tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
|
||||
Returns:
|
||||
Remaining quota amount
|
||||
"""
|
||||
from services.billing_service import BillingService
|
||||
|
||||
try:
|
||||
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
|
||||
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
|
||||
if isinstance(usage_info, dict):
|
||||
return usage_info.get("remaining", 0)
|
||||
# If it returns a simple number, treat it as remaining
|
||||
return int(usage_info) if usage_info else 0
|
||||
except Exception:
|
||||
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
|
||||
return -1
|
||||
|
||||
|
||||
def unlimited() -> QuotaCharge:
|
||||
"""
|
||||
Return a quota charge for unlimited quota.
|
||||
|
||||
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
|
||||
"""
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
|
||||
|
||||
@ -18,13 +18,12 @@ from core.app.features.rate_limiting import RateLimit
|
||||
from core.app.features.rate_limiting.rate_limit import rate_limit_context
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.db import session_factory
|
||||
from enums.quota_type import QuotaType
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from extensions.otel import AppGenerateHandler, trace_span
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
|
||||
|
||||
@ -107,7 +106,7 @@ class AppGenerateService:
|
||||
quota_charge = unlimited()
|
||||
if dify_config.BILLING_ENABLED:
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
|
||||
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
|
||||
except QuotaExceededError:
|
||||
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
|
||||
|
||||
@ -117,7 +116,6 @@ class AppGenerateService:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
quota_charge.commit()
|
||||
if app_model.mode == AppMode.COMPLETION:
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
|
||||
@ -22,7 +22,6 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
|
||||
from models.workflow import Workflow
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
|
||||
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -132,10 +131,9 @@ class AsyncWorkflowService:
|
||||
trigger_log = trigger_log_repo.create(trigger_log)
|
||||
session.commit()
|
||||
|
||||
# 7. Reserve quota (commit after successful dispatch)
|
||||
quota_charge = unlimited()
|
||||
# 7. Check and consume quota
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
|
||||
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
|
||||
except QuotaExceededError as e:
|
||||
# Update trigger log status
|
||||
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
|
||||
@ -155,18 +153,13 @@ class AsyncWorkflowService:
|
||||
# 9. Dispatch to appropriate queue
|
||||
task_data_dict = task_data.model_dump(mode="json")
|
||||
|
||||
try:
|
||||
task: AsyncResult[Any] | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict)
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict)
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict)
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
task: AsyncResult[Any] | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict)
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict)
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict)
|
||||
|
||||
# 10. Update trigger log with task info
|
||||
trigger_log.status = WorkflowTriggerStatus.QUEUED
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
@ -32,81 +32,6 @@ class SubscriptionPlan(TypedDict):
|
||||
expiration_date: int
|
||||
|
||||
|
||||
class QuotaReserveResult(TypedDict):
|
||||
reservation_id: str
|
||||
available: int
|
||||
reserved: int
|
||||
|
||||
|
||||
class QuotaCommitResult(TypedDict):
|
||||
available: int
|
||||
reserved: int
|
||||
refunded: int
|
||||
|
||||
|
||||
class QuotaReleaseResult(TypedDict):
|
||||
available: int
|
||||
reserved: int
|
||||
released: int
|
||||
|
||||
|
||||
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
|
||||
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
|
||||
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
|
||||
class _BillingQuota(TypedDict):
|
||||
size: int
|
||||
limit: int
|
||||
|
||||
|
||||
class _VectorSpaceQuota(TypedDict):
|
||||
size: float
|
||||
limit: int
|
||||
|
||||
|
||||
class _KnowledgeRateLimit(TypedDict):
|
||||
# NOTE (hj24):
|
||||
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
|
||||
# 2. Keep it for compatibility for now, can be deprecated in future versions.
|
||||
size: NotRequired[int]
|
||||
# NOTE END
|
||||
limit: int
|
||||
|
||||
|
||||
class _BillingSubscription(TypedDict):
|
||||
plan: str
|
||||
interval: str
|
||||
education: bool
|
||||
|
||||
|
||||
class BillingInfo(TypedDict):
|
||||
"""Response of /subscription/info.
|
||||
|
||||
NOTE (hj24):
|
||||
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
|
||||
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
|
||||
1. validate_python in non-strict mode will coerce it to the expected type
|
||||
2. In strict mode, it will raise ValidationError
|
||||
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
subscription: _BillingSubscription
|
||||
members: _BillingQuota
|
||||
apps: _BillingQuota
|
||||
vector_space: _VectorSpaceQuota
|
||||
knowledge_rate_limit: _KnowledgeRateLimit
|
||||
documents_upload_quota: _BillingQuota
|
||||
annotation_quota_limit: _BillingQuota
|
||||
docs_processing: str
|
||||
can_replace_logo: bool
|
||||
model_load_balancing_enabled: bool
|
||||
knowledge_pipeline_publish_enabled: bool
|
||||
next_credit_reset_date: NotRequired[int]
|
||||
|
||||
|
||||
_billing_info_adapter = TypeAdapter(BillingInfo)
|
||||
|
||||
|
||||
class BillingService:
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
@ -119,69 +44,19 @@ class BillingService:
|
||||
_PLAN_CACHE_TTL = 600
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str) -> BillingInfo:
|
||||
def get_info(cls, tenant_id: str):
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||
return _billing_info_adapter.validate_python(billing_info)
|
||||
return billing_info
|
||||
|
||||
@classmethod
|
||||
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
|
||||
"""Deprecated: Use get_quota_info instead."""
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
|
||||
return usage_info
|
||||
|
||||
@classmethod
|
||||
def get_quota_info(cls, tenant_id: str):
|
||||
params = {"tenant_id": tenant_id}
|
||||
return cls._send_request("GET", "/quota/info", params=params)
|
||||
|
||||
@classmethod
|
||||
def quota_reserve(
|
||||
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
|
||||
) -> QuotaReserveResult:
|
||||
"""Reserve quota before task execution."""
|
||||
payload: dict = {
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"request_id": request_id,
|
||||
"amount": amount,
|
||||
}
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
|
||||
|
||||
@classmethod
|
||||
def quota_commit(
|
||||
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
|
||||
) -> QuotaCommitResult:
|
||||
"""Commit a reservation with actual consumption."""
|
||||
payload: dict = {
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"reservation_id": reservation_id,
|
||||
"actual_amount": actual_amount,
|
||||
}
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
|
||||
|
||||
@classmethod
|
||||
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
|
||||
"""Release a reservation (cancel, return frozen quota)."""
|
||||
return _quota_release_adapter.validate_python(
|
||||
cls._send_request(
|
||||
"POST",
|
||||
"/quota/release",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"reservation_id": reservation_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str):
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
@ -281,7 +281,7 @@ class FeatureService:
|
||||
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
|
||||
features_usage_info = BillingService.get_quota_info(tenant_id)
|
||||
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
|
||||
|
||||
features.billing.enabled = billing_info["enabled"]
|
||||
features.billing.subscription.plan = billing_info["subscription"]["plan"]
|
||||
@ -312,10 +312,7 @@ class FeatureService:
|
||||
features.apps.limit = billing_info["apps"]["limit"]
|
||||
|
||||
if "vector_space" in billing_info:
|
||||
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
|
||||
# but LimitationModel.size is int; truncate here for compatibility
|
||||
features.vector_space.size = int(billing_info["vector_space"]["size"])
|
||||
# NOTE END
|
||||
features.vector_space.size = billing_info["vector_space"]["size"]
|
||||
features.vector_space.limit = billing_info["vector_space"]["limit"]
|
||||
|
||||
if "documents_upload_quota" in billing_info:
|
||||
@ -336,11 +333,7 @@ class FeatureService:
|
||||
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
|
||||
|
||||
if "knowledge_rate_limit" in billing_info:
|
||||
# NOTE (hj24):
|
||||
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
|
||||
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
|
||||
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
|
||||
# NOTE END
|
||||
|
||||
if "knowledge_pipeline_publish_enabled" in billing_info:
|
||||
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
|
||||
|
||||
@ -1,233 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enums.quota_type import QuotaType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCharge:
|
||||
"""
|
||||
Result of a quota reservation (Reserve phase).
|
||||
|
||||
Lifecycle:
|
||||
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
|
||||
try:
|
||||
do_work()
|
||||
charge.commit() # Confirm consumption
|
||||
except:
|
||||
charge.refund() # Release frozen quota
|
||||
|
||||
If neither commit() nor refund() is called, the billing system's
|
||||
cleanup CronJob will auto-release the reservation within ~75 seconds.
|
||||
"""
|
||||
|
||||
success: bool
|
||||
charge_id: str | None # reservation_id
|
||||
_quota_type: QuotaType
|
||||
_tenant_id: str | None = None
|
||||
_feature_key: str | None = None
|
||||
_amount: int = 0
|
||||
_committed: bool = field(default=False, repr=False)
|
||||
|
||||
def commit(self, actual_amount: int | None = None) -> None:
|
||||
"""
|
||||
Confirm the consumption with actual amount.
|
||||
|
||||
Args:
|
||||
actual_amount: Actual amount consumed. Defaults to the reserved amount.
|
||||
If less than reserved, the difference is refunded automatically.
|
||||
"""
|
||||
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
|
||||
return
|
||||
|
||||
try:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
amount = actual_amount if actual_amount is not None else self._amount
|
||||
BillingService.quota_commit(
|
||||
tenant_id=self._tenant_id,
|
||||
feature_key=self._feature_key,
|
||||
reservation_id=self.charge_id,
|
||||
actual_amount=amount,
|
||||
)
|
||||
self._committed = True
|
||||
logger.debug(
|
||||
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
|
||||
self._quota_type,
|
||||
self._tenant_id,
|
||||
self.charge_id,
|
||||
amount,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
|
||||
|
||||
def refund(self) -> None:
|
||||
"""
|
||||
Release the reserved quota (cancel the charge).
|
||||
|
||||
Safe to call even if:
|
||||
- charge failed or was disabled (charge_id is None)
|
||||
- already committed (Release after Commit is a no-op)
|
||||
- already refunded (idempotent)
|
||||
|
||||
This method guarantees no exceptions will be raised.
|
||||
"""
|
||||
if not self.charge_id or not self._tenant_id or not self._feature_key:
|
||||
return
|
||||
|
||||
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
|
||||
|
||||
|
||||
def unlimited() -> QuotaCharge:
|
||||
from enums.quota_type import QuotaType
|
||||
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
|
||||
|
||||
|
||||
class QuotaService:
|
||||
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
|
||||
|
||||
@staticmethod
|
||||
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Reserve + immediate Commit (one-shot mode).
|
||||
|
||||
The returned QuotaCharge supports .refund() which calls Release.
|
||||
For two-phase usage (e.g. streaming), use reserve() directly.
|
||||
"""
|
||||
charge = QuotaService.reserve(quota_type, tenant_id, amount)
|
||||
if charge.success and charge.charge_id:
|
||||
charge.commit()
|
||||
return charge
|
||||
|
||||
@staticmethod
|
||||
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Reserve quota before task execution (Reserve phase only).
|
||||
|
||||
The caller MUST call charge.commit() after the task succeeds,
|
||||
or charge.refund() if the task fails.
|
||||
|
||||
Raises:
|
||||
QuotaExceededError: When quota is insufficient
|
||||
"""
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.debug("Billing disabled, allowing request for %s", tenant_id)
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
|
||||
|
||||
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to reserve must be greater than 0")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
feature_key = quota_type.billing_key
|
||||
|
||||
try:
|
||||
reserve_resp = BillingService.quota_reserve(
|
||||
tenant_id=tenant_id,
|
||||
feature_key=feature_key,
|
||||
request_id=request_id,
|
||||
amount=amount,
|
||||
)
|
||||
|
||||
reservation_id = reserve_resp.get("reservation_id")
|
||||
if not reservation_id:
|
||||
logger.warning(
|
||||
"Reserve returned no reservation_id for %s, feature %s, response: %s",
|
||||
tenant_id,
|
||||
quota_type.value,
|
||||
reserve_resp,
|
||||
)
|
||||
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
|
||||
|
||||
logger.debug(
|
||||
"Reserved %d %s quota for tenant %s, reservation_id: %s",
|
||||
amount,
|
||||
quota_type.value,
|
||||
tenant_id,
|
||||
reservation_id,
|
||||
)
|
||||
return QuotaCharge(
|
||||
success=True,
|
||||
charge_id=reservation_id,
|
||||
_quota_type=quota_type,
|
||||
_tenant_id=tenant_id,
|
||||
_feature_key=feature_key,
|
||||
_amount=amount,
|
||||
)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return unlimited()
|
||||
|
||||
@staticmethod
|
||||
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return True
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to check must be greater than 0")
|
||||
|
||||
try:
|
||||
remaining = QuotaService.get_remaining(quota_type, tenant_id)
|
||||
return remaining >= amount if remaining != -1 else True
|
||||
except Exception:
|
||||
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
|
||||
"""Release a reservation. Guarantees no exceptions."""
|
||||
try:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return
|
||||
|
||||
if not reservation_id:
|
||||
return
|
||||
|
||||
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
|
||||
BillingService.quota_release(
|
||||
tenant_id=tenant_id,
|
||||
feature_key=feature_key,
|
||||
reservation_id=reservation_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
|
||||
|
||||
@staticmethod
|
||||
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
try:
|
||||
usage_info = BillingService.get_quota_info(tenant_id)
|
||||
if isinstance(usage_info, dict):
|
||||
feature_info = usage_info.get(quota_type.billing_key, {})
|
||||
if isinstance(feature_info, dict):
|
||||
limit = feature_info.get("limit", 0)
|
||||
usage = feature_info.get("usage", 0)
|
||||
if limit == -1:
|
||||
return -1
|
||||
return max(0, limit - usage)
|
||||
return 0
|
||||
except Exception:
|
||||
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return -1
|
||||
@ -38,7 +38,6 @@ from models.workflow import Workflow
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.workflow.entities import WebhookTriggerData
|
||||
|
||||
@ -783,9 +782,9 @@ class WebhookService:
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# reserve quota before triggering workflow execution
|
||||
# consume quota before triggering workflow execution
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
|
||||
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
|
||||
logger.info(
|
||||
@ -796,16 +795,11 @@ class WebhookService:
|
||||
raise
|
||||
|
||||
# Trigger workflow execution asynchronously
|
||||
try:
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session,
|
||||
end_user,
|
||||
trigger_data,
|
||||
)
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session,
|
||||
end_user,
|
||||
trigger_data,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
|
||||
|
||||
@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from enums.quota_type import QuotaType
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from models.enums import (
|
||||
AppTriggerType,
|
||||
CreatorUserRole,
|
||||
@ -42,7 +42,6 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
|
||||
@ -299,10 +298,10 @@ def dispatch_triggered_workflow(
|
||||
icon_dark_filename=trigger_entity.identity.icon_dark or "",
|
||||
)
|
||||
|
||||
# reserve quota before invoking trigger
|
||||
# consume quota before invoking trigger
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
|
||||
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
|
||||
logger.info(
|
||||
@ -388,7 +387,6 @@ def dispatch_triggered_workflow(
|
||||
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
|
||||
|
||||
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
|
||||
quota_charge.commit()
|
||||
dispatched_count += 1
|
||||
logger.info(
|
||||
"Triggered workflow for app %s with trigger event %s",
|
||||
|
||||
@ -8,11 +8,10 @@ from core.workflow.nodes.trigger_schedule.exc import (
|
||||
ScheduleNotFoundError,
|
||||
TenantOwnerNotFoundError,
|
||||
)
|
||||
from enums.quota_type import QuotaType
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
from services.workflow.entities import ScheduleTriggerData
|
||||
@ -44,7 +43,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
|
||||
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
|
||||
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
|
||||
@ -62,7 +61,6 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
tenant_id=schedule.tenant_id,
|
||||
),
|
||||
)
|
||||
quota_charge.commit()
|
||||
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
|
||||
except Exception as e:
|
||||
quota_charge.refund()
|
||||
|
||||
@ -36,19 +36,12 @@ class TestAppGenerateService:
|
||||
) as mock_message_based_generator,
|
||||
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
|
||||
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
|
||||
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
|
||||
):
|
||||
# Setup default mock returns for billing service
|
||||
mock_billing_service.quota_reserve.return_value = {
|
||||
"reservation_id": "test-reservation-id",
|
||||
"available": 100,
|
||||
"reserved": 1,
|
||||
}
|
||||
mock_billing_service.quota_commit.return_value = {
|
||||
"available": 99,
|
||||
"reserved": 0,
|
||||
"refunded": 0,
|
||||
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
|
||||
"result": "success",
|
||||
"history_id": "test_history_id",
|
||||
}
|
||||
|
||||
# Setup default mock returns for workflow service
|
||||
@ -108,8 +101,6 @@ class TestAppGenerateService:
|
||||
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
|
||||
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
|
||||
mock_quota_dify_config.BILLING_ENABLED = False
|
||||
|
||||
mock_global_dify_config.BILLING_ENABLED = False
|
||||
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
@ -127,7 +118,6 @@ class TestAppGenerateService:
|
||||
"message_based_generator": mock_message_based_generator,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"dify_config": mock_dify_config,
|
||||
"quota_dify_config": mock_quota_dify_config,
|
||||
"global_dify_config": mock_global_dify_config,
|
||||
}
|
||||
|
||||
@ -475,7 +465,6 @@ class TestAppGenerateService:
|
||||
|
||||
# Set BILLING_ENABLED to True for this test
|
||||
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
|
||||
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
|
||||
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
|
||||
|
||||
# Setup test arguments
|
||||
@ -489,10 +478,8 @@ class TestAppGenerateService:
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify billing two-phase quota (reserve + commit)
|
||||
billing = mock_external_service_dependencies["billing_service"]
|
||||
billing.quota_reserve.assert_called_once()
|
||||
billing.quota_commit.assert_called_once()
|
||||
# Verify billing service was called to consume quota
|
||||
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
|
||||
|
||||
def test_generate_with_invalid_app_mode(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
|
||||
@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log(
|
||||
)
|
||||
|
||||
# Mock quota to avoid rate limiting
|
||||
from services import quota_service
|
||||
from enums import quota_type
|
||||
|
||||
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
|
||||
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
|
||||
|
||||
# Execute schedule trigger
|
||||
workflow_schedule_tasks.run_schedule_trigger(plan.id)
|
||||
|
||||
@ -1,349 +0,0 @@
|
||||
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enums.quota_type import QuotaType
|
||||
from services.quota_service import QuotaCharge, QuotaService, unlimited
|
||||
|
||||
|
||||
class TestQuotaType:
|
||||
def test_billing_key_trigger(self):
|
||||
assert QuotaType.TRIGGER.billing_key == "trigger_event"
|
||||
|
||||
def test_billing_key_workflow(self):
|
||||
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
|
||||
|
||||
def test_billing_key_unlimited_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid quota type"):
|
||||
_ = QuotaType.UNLIMITED.billing_key
|
||||
|
||||
|
||||
class TestQuotaService:
|
||||
def test_reserve_billing_disabled(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService"),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
|
||||
def test_reserve_zero_amount_raises(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
|
||||
|
||||
def test_reserve_success(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
|
||||
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
|
||||
|
||||
assert charge.success is True
|
||||
assert charge.charge_id == "rid-1"
|
||||
assert charge._tenant_id == "t1"
|
||||
assert charge._feature_key == "trigger_event"
|
||||
assert charge._amount == 1
|
||||
mock_bs.quota_reserve.assert_called_once()
|
||||
|
||||
def test_reserve_no_reservation_id_raises(self):
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {}
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
|
||||
def test_reserve_quota_exceeded_propagates(self):
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
|
||||
def test_reserve_api_exception_returns_unlimited(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.side_effect = RuntimeError("network")
|
||||
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
|
||||
def test_consume_calls_reserve_and_commit(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
|
||||
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
mock_bs.quota_commit.assert_called_once()
|
||||
|
||||
def test_check_billing_disabled(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
|
||||
|
||||
def test_check_zero_amount_raises(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
|
||||
|
||||
def test_check_sufficient_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=100),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
|
||||
|
||||
def test_check_insufficient_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=5),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
|
||||
|
||||
def test_check_unlimited_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=-1),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
|
||||
|
||||
def test_check_exception_returns_true(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
|
||||
|
||||
def test_release_billing_disabled(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_not_called()
|
||||
|
||||
def test_release_empty_reservation(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_not_called()
|
||||
|
||||
def test_release_success(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_release.return_value = {}
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_called_once_with(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
|
||||
)
|
||||
|
||||
def test_release_exception_swallowed(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_release.side_effect = RuntimeError("fail")
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
|
||||
def test_get_remaining_normal(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
|
||||
|
||||
def test_get_remaining_unlimited(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
|
||||
|
||||
def test_get_remaining_over_limit_returns_zero(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_exception_returns_neg1(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.side_effect = RuntimeError
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
|
||||
|
||||
def test_get_remaining_empty_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_non_dict_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = "invalid"
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_feature_not_in_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
|
||||
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
|
||||
assert remaining == 0
|
||||
|
||||
def test_get_remaining_non_dict_feature_info(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
|
||||
class TestQuotaCharge:
|
||||
def test_commit_success(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_called_once_with(
|
||||
tenant_id="t1",
|
||||
feature_key="trigger_event",
|
||||
reservation_id="rid-1",
|
||||
actual_amount=1,
|
||||
)
|
||||
assert charge._committed is True
|
||||
|
||||
def test_commit_with_actual_amount(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=10,
|
||||
)
|
||||
charge.commit(actual_amount=5)
|
||||
call_kwargs = mock_bs.quota_commit.call_args[1]
|
||||
assert call_kwargs["actual_amount"] == 5
|
||||
|
||||
def test_commit_idempotent(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
charge.commit()
|
||||
assert mock_bs.quota_commit.call_count == 1
|
||||
|
||||
def test_commit_no_charge_id_noop(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_not_called()
|
||||
|
||||
def test_commit_no_tenant_id_noop(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id=None,
|
||||
_feature_key="trigger_event",
|
||||
)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_not_called()
|
||||
|
||||
def test_commit_exception_swallowed(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.side_effect = RuntimeError("fail")
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
|
||||
def test_refund_success(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
)
|
||||
charge.refund()
|
||||
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
|
||||
def test_refund_no_charge_id_noop(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
|
||||
charge.refund()
|
||||
mock_rel.assert_not_called()
|
||||
|
||||
def test_refund_no_tenant_id_noop(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id=None,
|
||||
)
|
||||
charge.refund()
|
||||
mock_rel.assert_not_called()
|
||||
|
||||
|
||||
class TestUnlimited:
|
||||
def test_unlimited_returns_success_with_no_charge_id(self):
|
||||
charge = unlimited()
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
assert charge._quota_type == QuotaType.UNLIMITED
|
||||
@ -23,7 +23,6 @@ import pytest
|
||||
|
||||
import services.app_generate_service as ags_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from enums.quota_type import QuotaType
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
@ -448,8 +447,8 @@ class TestGenerateBilling:
|
||||
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
reserve_mock = mocker.patch(
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
consume_mock = mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
@ -468,8 +467,7 @@ class TestGenerateBilling:
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
|
||||
quota_charge.commit.assert_called_once()
|
||||
consume_mock.assert_called_once_with("tenant-id")
|
||||
|
||||
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
|
||||
from services.errors.app import QuotaExceededError
|
||||
@ -477,7 +475,7 @@ class TestGenerateBilling:
|
||||
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
|
||||
)
|
||||
|
||||
@ -494,7 +492,7 @@ class TestGenerateBilling:
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
|
||||
@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
|
||||
- repo: SQLAlchemyWorkflowTriggerLogRepository
|
||||
- dispatcher_manager_class: QueueDispatcherManager class
|
||||
- dispatcher: dispatcher instance
|
||||
- quota_service: QuotaService mock
|
||||
- quota_workflow: QuotaType.WORKFLOW
|
||||
- get_workflow: AsyncWorkflowService._get_workflow method
|
||||
- professional_task: execute_workflow_professional
|
||||
- team_task: execute_workflow_team
|
||||
@ -72,7 +72,7 @@ class TestAsyncWorkflowService:
|
||||
mock_repo.create.side_effect = _create_side_effect
|
||||
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_quota_service = MagicMock()
|
||||
quota_workflow = MagicMock()
|
||||
mock_get_workflow = MagicMock()
|
||||
|
||||
mock_professional_task = MagicMock()
|
||||
@ -93,8 +93,8 @@ class TestAsyncWorkflowService:
|
||||
) as mock_get_workflow,
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"QuotaService",
|
||||
new=mock_quota_service,
|
||||
"QuotaType",
|
||||
new=SimpleNamespace(WORKFLOW=quota_workflow),
|
||||
),
|
||||
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
|
||||
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
|
||||
@ -107,7 +107,7 @@ class TestAsyncWorkflowService:
|
||||
"repo": mock_repo,
|
||||
"dispatcher_manager_class": mock_dispatcher_manager_class,
|
||||
"dispatcher": mock_dispatcher,
|
||||
"quota_service": mock_quota_service,
|
||||
"quota_workflow": quota_workflow,
|
||||
"get_workflow": mock_get_workflow,
|
||||
"professional_task": mock_professional_task,
|
||||
"team_task": mock_team_task,
|
||||
@ -146,9 +146,6 @@ class TestAsyncWorkflowService:
|
||||
mocks["team_task"].delay.return_value = task_result
|
||||
mocks["sandbox_task"].delay.return_value = task_result
|
||||
|
||||
quota_charge_mock = MagicMock()
|
||||
mocks["quota_service"].reserve.return_value = quota_charge_mock
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id: str):
|
||||
self.id = user_id
|
||||
@ -166,8 +163,7 @@ class TestAsyncWorkflowService:
|
||||
assert result.status == "queued"
|
||||
assert result.queue == queue_name
|
||||
|
||||
mocks["quota_service"].reserve.assert_called_once()
|
||||
quota_charge_mock.commit.assert_called_once()
|
||||
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
|
||||
assert session.commit.call_count == 2
|
||||
|
||||
created_log = mocks["repo"].create.call_args[0][0]
|
||||
@ -254,7 +250,7 @@ class TestAsyncWorkflowService:
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
|
||||
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
|
||||
feature="workflow",
|
||||
tenant_id="tenant-123",
|
||||
required=1,
|
||||
|
||||
@ -290,19 +290,9 @@ class TestBillingServiceSubscriptionInfo:
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {
|
||||
"enabled": True,
|
||||
"subscription": {"plan": "professional", "interval": "month", "education": False},
|
||||
"members": {"size": 1, "limit": 50},
|
||||
"apps": {"size": 1, "limit": 200},
|
||||
"vector_space": {"size": 0.0, "limit": 20480},
|
||||
"knowledge_rate_limit": {"limit": 1000},
|
||||
"documents_upload_quota": {"size": 0, "limit": 1000},
|
||||
"annotation_quota_limit": {"size": 0, "limit": 5000},
|
||||
"docs_processing": "top-priority",
|
||||
"can_replace_logo": True,
|
||||
"model_load_balancing_enabled": True,
|
||||
"knowledge_pipeline_publish_enabled": True,
|
||||
"next_credit_reset_date": 1775952000,
|
||||
"subscription_plan": "professional",
|
||||
"billing_cycle": "monthly",
|
||||
"status": "active",
|
||||
}
|
||||
mock_send_request.return_value = expected_response
|
||||
|
||||
@ -425,7 +415,7 @@ class TestBillingServiceUsageCalculation:
|
||||
yield mock
|
||||
|
||||
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
|
||||
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
|
||||
"""Test retrieval of tenant feature plan usage information."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
|
||||
@ -438,20 +428,6 @@ class TestBillingServiceUsageCalculation:
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
|
||||
|
||||
def test_get_quota_info(self, mock_send_request):
|
||||
"""Test retrieval of quota info from new endpoint."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
|
||||
mock_send_request.return_value = expected_response
|
||||
|
||||
# Act
|
||||
result = BillingService.get_quota_info(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
|
||||
|
||||
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
|
||||
"""Test updating tenant feature usage with positive delta (adding credits)."""
|
||||
# Arrange
|
||||
@ -529,118 +505,6 @@ class TestBillingServiceUsageCalculation:
|
||||
)
|
||||
|
||||
|
||||
class TestBillingServiceQuotaOperations:
|
||||
"""Unit tests for quota reserve/commit/release operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_request(self):
|
||||
with patch.object(BillingService, "_send_request") as mock:
|
||||
yield mock
|
||||
|
||||
def test_quota_reserve_success(self, mock_send_request):
|
||||
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/reserve",
|
||||
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
|
||||
)
|
||||
|
||||
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
|
||||
|
||||
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
|
||||
|
||||
assert result["available"] == 99
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["reserved"] == 1
|
||||
assert isinstance(result["reserved"], int)
|
||||
|
||||
def test_quota_reserve_with_meta(self, mock_send_request):
|
||||
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
|
||||
meta = {"source": "webhook"}
|
||||
|
||||
BillingService.quota_reserve(
|
||||
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
|
||||
)
|
||||
|
||||
call_json = mock_send_request.call_args[1]["json"]
|
||||
assert call_json["meta"] == {"source": "webhook"}
|
||||
|
||||
def test_quota_commit_success(self, mock_send_request):
|
||||
expected = {"available": 98, "reserved": 0, "refunded": 0}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/commit",
|
||||
json={
|
||||
"tenant_id": "t1",
|
||||
"feature_key": "trigger_event",
|
||||
"reservation_id": "rid-1",
|
||||
"actual_amount": 1,
|
||||
},
|
||||
)
|
||||
|
||||
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
|
||||
|
||||
result = BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
|
||||
)
|
||||
|
||||
assert result["available"] == 97
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["refunded"] == 1
|
||||
assert isinstance(result["refunded"], int)
|
||||
|
||||
def test_quota_commit_with_meta(self, mock_send_request):
|
||||
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
|
||||
meta = {"reason": "partial"}
|
||||
|
||||
BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
|
||||
)
|
||||
|
||||
call_json = mock_send_request.call_args[1]["json"]
|
||||
assert call_json["meta"] == {"reason": "partial"}
|
||||
|
||||
def test_quota_release_success(self, mock_send_request):
|
||||
expected = {"available": 100, "reserved": 0, "released": 1}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/release",
|
||||
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
|
||||
)
|
||||
|
||||
def test_quota_release_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
|
||||
|
||||
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
|
||||
|
||||
assert result["available"] == 100
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["released"] == 1
|
||||
assert isinstance(result["released"], int)
|
||||
|
||||
|
||||
class TestBillingServiceRateLimitEnforcement:
|
||||
"""Unit tests for rate limit enforcement mechanisms.
|
||||
|
||||
@ -1131,14 +995,17 @@ class TestBillingServiceEdgeCases:
|
||||
yield mock
|
||||
|
||||
def test_get_info_empty_response(self, mock_send_request):
|
||||
"""Empty response from billing API should raise ValidationError due to missing required fields."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
"""Test handling of empty billing info response."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-empty"
|
||||
mock_send_request.return_value = {}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
BillingService.get_info(tenant_id)
|
||||
# Act
|
||||
result = BillingService.get_info(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
mock_send_request.assert_called_once()
|
||||
|
||||
def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request):
|
||||
"""Test updating tenant feature usage with zero delta (no change)."""
|
||||
@ -1549,21 +1416,12 @@ class TestBillingServiceIntegrationScenarios:
|
||||
|
||||
# Step 1: Get current billing info
|
||||
mock_send_request.return_value = {
|
||||
"enabled": True,
|
||||
"subscription": {"plan": "sandbox", "interval": "", "education": False},
|
||||
"members": {"size": 0, "limit": 1},
|
||||
"apps": {"size": 0, "limit": 5},
|
||||
"vector_space": {"size": 0.0, "limit": 50},
|
||||
"knowledge_rate_limit": {"limit": 10},
|
||||
"documents_upload_quota": {"size": 0, "limit": 50},
|
||||
"annotation_quota_limit": {"size": 0, "limit": 10},
|
||||
"docs_processing": "standard",
|
||||
"can_replace_logo": False,
|
||||
"model_load_balancing_enabled": False,
|
||||
"knowledge_pipeline_publish_enabled": False,
|
||||
"subscription_plan": "sandbox",
|
||||
"billing_cycle": "monthly",
|
||||
"status": "active",
|
||||
}
|
||||
current_info = BillingService.get_info(tenant_id)
|
||||
assert current_info["subscription"]["plan"] == "sandbox"
|
||||
assert current_info["subscription_plan"] == "sandbox"
|
||||
|
||||
# Step 2: Get payment link for upgrade
|
||||
mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"}
|
||||
@ -1677,140 +1535,3 @@ class TestBillingServiceIntegrationScenarios:
|
||||
mock_send_request.return_value = {"result": "success", "activated": True}
|
||||
activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student")
|
||||
assert activate_result["activated"] is True
|
||||
|
||||
|
||||
class TestBillingServiceSubscriptionInfoDataType:
|
||||
"""Unit tests for data type coercion in BillingService.get_info
|
||||
|
||||
1. Verifies the get_info returns correct Python types for numeric fields
|
||||
2. Ensure the compatibility regardless of what results the upstream billing API returns
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_request(self):
|
||||
with patch.object(BillingService, "_send_request") as mock:
|
||||
yield mock
|
||||
|
||||
@pytest.fixture
|
||||
def normal_billing_response(self) -> dict:
|
||||
return {
|
||||
"enabled": True,
|
||||
"subscription": {
|
||||
"plan": "team",
|
||||
"interval": "year",
|
||||
"education": False,
|
||||
},
|
||||
"members": {"size": 10, "limit": 50},
|
||||
"apps": {"size": 80, "limit": 200},
|
||||
"vector_space": {"size": 5120.75, "limit": 20480},
|
||||
"knowledge_rate_limit": {"limit": 1000},
|
||||
"documents_upload_quota": {"size": 450, "limit": 1000},
|
||||
"annotation_quota_limit": {"size": 1200, "limit": 5000},
|
||||
"docs_processing": "top-priority",
|
||||
"can_replace_logo": True,
|
||||
"model_load_balancing_enabled": True,
|
||||
"knowledge_pipeline_publish_enabled": True,
|
||||
"next_credit_reset_date": 1745971200,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def string_billing_response(self) -> dict:
|
||||
return {
|
||||
"enabled": True,
|
||||
"subscription": {
|
||||
"plan": "team",
|
||||
"interval": "year",
|
||||
"education": False,
|
||||
},
|
||||
"members": {"size": "10", "limit": "50"},
|
||||
"apps": {"size": "80", "limit": "200"},
|
||||
"vector_space": {"size": "5120.75", "limit": "20480"},
|
||||
"knowledge_rate_limit": {"limit": "1000"},
|
||||
"documents_upload_quota": {"size": "450", "limit": "1000"},
|
||||
"annotation_quota_limit": {"size": "1200", "limit": "5000"},
|
||||
"docs_processing": "top-priority",
|
||||
"can_replace_logo": True,
|
||||
"model_load_balancing_enabled": True,
|
||||
"knowledge_pipeline_publish_enabled": True,
|
||||
"next_credit_reset_date": "1745971200",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _assert_billing_info_types(result: dict):
|
||||
assert isinstance(result["enabled"], bool)
|
||||
assert isinstance(result["subscription"]["plan"], str)
|
||||
assert isinstance(result["subscription"]["interval"], str)
|
||||
assert isinstance(result["subscription"]["education"], bool)
|
||||
|
||||
assert isinstance(result["members"]["size"], int)
|
||||
assert isinstance(result["members"]["limit"], int)
|
||||
|
||||
assert isinstance(result["apps"]["size"], int)
|
||||
assert isinstance(result["apps"]["limit"], int)
|
||||
|
||||
assert isinstance(result["vector_space"]["size"], float)
|
||||
assert isinstance(result["vector_space"]["limit"], int)
|
||||
|
||||
assert isinstance(result["knowledge_rate_limit"]["limit"], int)
|
||||
|
||||
assert isinstance(result["documents_upload_quota"]["size"], int)
|
||||
assert isinstance(result["documents_upload_quota"]["limit"], int)
|
||||
|
||||
assert isinstance(result["annotation_quota_limit"]["size"], int)
|
||||
assert isinstance(result["annotation_quota_limit"]["limit"], int)
|
||||
|
||||
assert isinstance(result["docs_processing"], str)
|
||||
assert isinstance(result["can_replace_logo"], bool)
|
||||
assert isinstance(result["model_load_balancing_enabled"], bool)
|
||||
assert isinstance(result["knowledge_pipeline_publish_enabled"], bool)
|
||||
if "next_credit_reset_date" in result:
|
||||
assert isinstance(result["next_credit_reset_date"], int)
|
||||
|
||||
def test_get_info_with_normal_types(self, mock_send_request, normal_billing_response):
|
||||
"""When the billing API returns native numeric types, get_info should preserve them."""
|
||||
mock_send_request.return_value = normal_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
self._assert_billing_info_types(result)
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
|
||||
|
||||
def test_get_info_with_string_types(self, mock_send_request, string_billing_response):
|
||||
"""When the billing API returns numeric values as strings, get_info should coerce them."""
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
self._assert_billing_info_types(result)
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
|
||||
|
||||
def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response):
|
||||
"""NotRequired fields can be absent without raising."""
|
||||
del string_billing_response["next_credit_reset_date"]
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
assert "next_credit_reset_date" not in result
|
||||
self._assert_billing_info_types(result)
|
||||
|
||||
def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response):
|
||||
"""Undefined fields are silently stripped by validate_python."""
|
||||
string_billing_response["new_feature"] = "something"
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
# extra fields are dropped by TypeAdapter on TypedDict
|
||||
assert "new_feature" not in result
|
||||
self._assert_billing_info_types(result)
|
||||
|
||||
def test_get_info_missing_required_field_raises(self, mock_send_request, string_billing_response):
|
||||
"""Missing a required field should raise ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
del string_billing_response["members"]
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
BillingService.get_info("tenant-type-test")
|
||||
|
||||
@ -55,7 +55,7 @@ describe('DatasetsLayout', () => {
|
||||
setAppContext()
|
||||
})
|
||||
|
||||
it('should render loading when workspace is still loading', () => {
|
||||
it('should keep rendering children when workspace is still loading', () => {
|
||||
setAppContext({
|
||||
isLoadingCurrentWorkspace: true,
|
||||
currentWorkspace: { id: '' },
|
||||
@ -67,8 +67,7 @@ describe('DatasetsLayout', () => {
|
||||
</DatasetsLayout>
|
||||
))
|
||||
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('datasets-content')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('datasets-content')).toBeInTheDocument()
|
||||
expect(mockReplace).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
|
||||
@ -1,29 +1,18 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect } from 'react'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { ExternalApiPanelProvider } from '@/context/external-api-panel-context'
|
||||
import { ExternalKnowledgeApiProvider } from '@/context/external-knowledge-api-context'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { redirect } from '@/next/navigation'
|
||||
|
||||
export default function DatasetsLayout({ children }: { children: React.ReactNode }) {
|
||||
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, currentWorkspace, isLoadingCurrentWorkspace } = useAppContext()
|
||||
const router = useRouter()
|
||||
const shouldRedirect = !isLoadingCurrentWorkspace
|
||||
&& currentWorkspace.id
|
||||
&& !(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator)
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldRedirect)
|
||||
router.replace('/apps')
|
||||
}, [shouldRedirect, router])
|
||||
|
||||
if (isLoadingCurrentWorkspace || !currentWorkspace.id)
|
||||
return <Loading type="app" />
|
||||
|
||||
if (shouldRedirect) {
|
||||
return null
|
||||
return redirect('/apps')
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import * as React from 'react'
|
||||
import { AppInitializer } from '@/app/components/app-initializer'
|
||||
import InSiteMessageNotification from '@/app/components/app/in-site-message/notification'
|
||||
import AmplitudeProvider from '@/app/components/base/amplitude'
|
||||
@ -14,7 +13,6 @@ import { EventEmitterContextProvider } from '@/context/event-emitter-provider'
|
||||
import { ModalContextProvider } from '@/context/modal-context-provider'
|
||||
import { ProviderContextProvider } from '@/context/provider-context-provider'
|
||||
import PartnerStack from '../components/billing/partner-stack'
|
||||
import Splash from '../components/splash'
|
||||
import RoleRouteGuard from './role-route-guard'
|
||||
|
||||
const Layout = ({ children }: { children: ReactNode }) => {
|
||||
@ -37,7 +35,6 @@ const Layout = ({ children }: { children: ReactNode }) => {
|
||||
<PartnerStack />
|
||||
<ReadmePanel />
|
||||
<GotoAnything />
|
||||
<Splash />
|
||||
</ModalContextProvider>
|
||||
</ProviderContextProvider>
|
||||
</EventEmitterContextProvider>
|
||||
|
||||
@ -41,7 +41,7 @@ describe('RoleRouteGuard', () => {
|
||||
setAppContext()
|
||||
})
|
||||
|
||||
it('should render loading while workspace is loading', () => {
|
||||
it('should hide guarded content while workspace is loading', () => {
|
||||
setAppContext({
|
||||
isLoadingCurrentWorkspace: true,
|
||||
})
|
||||
@ -52,7 +52,6 @@ describe('RoleRouteGuard', () => {
|
||||
</RoleRouteGuard>
|
||||
))
|
||||
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('guarded-content')).not.toBeInTheDocument()
|
||||
expect(mockReplace).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
'use client'
|
||||
|
||||
import type { ReactNode } from 'react'
|
||||
import { useEffect } from 'react'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { usePathname, useRouter } from '@/next/navigation'
|
||||
import { redirect, usePathname } from '@/next/navigation'
|
||||
|
||||
const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const
|
||||
|
||||
@ -13,21 +11,11 @@ const isPathUnderRoute = (pathname: string, route: string) => pathname === route
|
||||
export default function RoleRouteGuard({ children }: { children: ReactNode }) {
|
||||
const { isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
|
||||
const pathname = usePathname()
|
||||
const router = useRouter()
|
||||
const shouldGuardRoute = datasetOperatorRedirectRoutes.some(route => isPathUnderRoute(pathname, route))
|
||||
const shouldRedirect = shouldGuardRoute && !isLoadingCurrentWorkspace && isCurrentWorkspaceDatasetOperator
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldRedirect)
|
||||
router.replace('/datasets')
|
||||
}, [shouldRedirect, router])
|
||||
|
||||
// Block rendering only for guarded routes to avoid permission flicker.
|
||||
if (shouldGuardRoute && isLoadingCurrentWorkspace)
|
||||
return <Loading type="app" />
|
||||
|
||||
if (shouldRedirect)
|
||||
return null
|
||||
return redirect('/datasets')
|
||||
|
||||
return <>{children}</>
|
||||
}
|
||||
|
||||
@ -0,0 +1,158 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share'
|
||||
import AuthenticatedLayout from '../authenticated-layout'
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
const mockShareCode = 'share-code'
|
||||
const mockUpdateAppInfo = vi.fn()
|
||||
const mockUpdateAppParams = vi.fn()
|
||||
const mockUpdateWebAppMeta = vi.fn()
|
||||
const mockUpdateUserCanAccessApp = vi.fn()
|
||||
|
||||
const mockAppInfo = {
|
||||
app_id: 'app-123',
|
||||
}
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
}),
|
||||
usePathname: () => '/chat/test-share-code',
|
||||
useSearchParams: () => new URLSearchParams(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/web-app-context', () => ({
|
||||
useWebAppStore: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/access-control', () => ({
|
||||
useGetUserCanAccessApp: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-share', () => ({
|
||||
useGetWebAppInfo: vi.fn(),
|
||||
useGetWebAppParams: vi.fn(),
|
||||
useGetWebAppMeta: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/webapp-auth', () => ({
|
||||
webAppLogout: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('AuthenticatedLayout', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
;(useWebAppStore as unknown as Mock).mockImplementation((selector: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
shareCode: mockShareCode,
|
||||
updateAppInfo: mockUpdateAppInfo,
|
||||
updateAppParams: mockUpdateAppParams,
|
||||
updateWebAppMeta: mockUpdateWebAppMeta,
|
||||
updateUserCanAccessApp: mockUpdateUserCanAccessApp,
|
||||
}
|
||||
return selector(state)
|
||||
})
|
||||
|
||||
;(useGetWebAppInfo as Mock).mockReturnValue({
|
||||
data: mockAppInfo,
|
||||
error: null,
|
||||
isPending: false,
|
||||
})
|
||||
|
||||
;(useGetWebAppParams as Mock).mockReturnValue({
|
||||
data: { user_input_form: [] },
|
||||
error: null,
|
||||
isPending: false,
|
||||
})
|
||||
|
||||
;(useGetWebAppMeta as Mock).mockReturnValue({
|
||||
data: { tool_icons: {} },
|
||||
error: null,
|
||||
isPending: false,
|
||||
})
|
||||
|
||||
;(useGetUserCanAccessApp as Mock).mockReturnValue({
|
||||
data: { result: true },
|
||||
error: null,
|
||||
isPending: false,
|
||||
})
|
||||
})
|
||||
|
||||
describe('Permission Gating', () => {
|
||||
it('should not render children while the app info needed for permission is still pending', () => {
|
||||
;(useGetWebAppInfo as Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
error: null,
|
||||
isPending: true,
|
||||
})
|
||||
|
||||
render(
|
||||
<AuthenticatedLayout>
|
||||
<div>protected child</div>
|
||||
</AuthenticatedLayout>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('protected child')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render children while the access check is still pending', () => {
|
||||
;(useGetUserCanAccessApp as Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
error: null,
|
||||
isPending: true,
|
||||
})
|
||||
|
||||
render(
|
||||
<AuthenticatedLayout>
|
||||
<div>protected child</div>
|
||||
</AuthenticatedLayout>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('protected child')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render children once access is allowed even if metadata queries are still pending', () => {
|
||||
;(useGetWebAppParams as Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
error: null,
|
||||
isPending: true,
|
||||
})
|
||||
|
||||
;(useGetWebAppMeta as Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
error: null,
|
||||
isPending: true,
|
||||
})
|
||||
|
||||
render(
|
||||
<AuthenticatedLayout>
|
||||
<div>protected child</div>
|
||||
</AuthenticatedLayout>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('protected child')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the no permission state when access is denied', () => {
|
||||
;(useGetUserCanAccessApp as Mock).mockReturnValue({
|
||||
data: { result: false },
|
||||
error: null,
|
||||
isPending: false,
|
||||
})
|
||||
|
||||
render(
|
||||
<AuthenticatedLayout>
|
||||
<div>protected child</div>
|
||||
</AuthenticatedLayout>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('protected child')).not.toBeInTheDocument()
|
||||
expect(screen.getByText(/403/)).toBeInTheDocument()
|
||||
expect(screen.getByText(/no permission/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
123
web/app/(shareLayout)/components/__tests__/splash.spec.tsx
Normal file
123
web/app/(shareLayout)/components/__tests__/splash.spec.tsx
Normal file
@ -0,0 +1,123 @@
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import Splash from '../splash'
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
const mockWebAppLoginStatus = vi.fn()
|
||||
const mockFetchAccessToken = vi.fn()
|
||||
const mockSetWebAppAccessToken = vi.fn()
|
||||
const mockSetWebAppPassport = vi.fn()
|
||||
const mockWebAppLogout = vi.fn()
|
||||
|
||||
let mockShareCode: string | null = null
|
||||
let mockEmbeddedUserId: string | null = null
|
||||
let mockMessage: string | null = null
|
||||
let mockRedirectUrl: string | null = '/chat/test-share-code'
|
||||
let mockCode: string | null = null
|
||||
let mockTokenFromUrl: string | null = null
|
||||
|
||||
vi.mock('@/context/web-app-context', () => ({
|
||||
useWebAppStore: (selector: (state: { shareCode: string | null, embeddedUserId: string | null }) => unknown) =>
|
||||
selector({
|
||||
shareCode: mockShareCode,
|
||||
embeddedUserId: mockEmbeddedUserId,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
}),
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => {
|
||||
if (key === 'redirect_url')
|
||||
return mockRedirectUrl
|
||||
if (key === 'message')
|
||||
return mockMessage
|
||||
if (key === 'code')
|
||||
return mockCode
|
||||
if (key === 'web_sso_token')
|
||||
return mockTokenFromUrl
|
||||
return null
|
||||
},
|
||||
toString: () => {
|
||||
const params = new URLSearchParams()
|
||||
if (mockRedirectUrl)
|
||||
params.set('redirect_url', mockRedirectUrl)
|
||||
if (mockMessage)
|
||||
params.set('message', mockMessage)
|
||||
if (mockCode)
|
||||
params.set('code', mockCode)
|
||||
if (mockTokenFromUrl)
|
||||
params.set('web_sso_token', mockTokenFromUrl)
|
||||
return params.toString()
|
||||
},
|
||||
* [Symbol.iterator]() {
|
||||
const params = new URLSearchParams(this.toString())
|
||||
yield* params.entries()
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/share', () => ({
|
||||
fetchAccessToken: (...args: unknown[]) => mockFetchAccessToken(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/webapp-auth', () => ({
|
||||
setWebAppAccessToken: (...args: unknown[]) => mockSetWebAppAccessToken(...args),
|
||||
setWebAppPassport: (...args: unknown[]) => mockSetWebAppPassport(...args),
|
||||
webAppLoginStatus: (...args: unknown[]) => mockWebAppLoginStatus(...args),
|
||||
webAppLogout: (...args: unknown[]) => mockWebAppLogout(...args),
|
||||
}))
|
||||
|
||||
describe('Share Splash', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockShareCode = null
|
||||
mockEmbeddedUserId = null
|
||||
mockMessage = null
|
||||
mockRedirectUrl = '/chat/test-share-code'
|
||||
mockCode = null
|
||||
mockTokenFromUrl = null
|
||||
mockWebAppLoginStatus.mockResolvedValue({
|
||||
userLoggedIn: true,
|
||||
appLoggedIn: true,
|
||||
})
|
||||
mockFetchAccessToken.mockResolvedValue({ access_token: 'token' })
|
||||
})
|
||||
|
||||
describe('Share Code Guard', () => {
|
||||
it('should skip login-status checks until the share code is available', async () => {
|
||||
render(
|
||||
<Splash>
|
||||
<div>share child</div>
|
||||
</Splash>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('share child')).toBeInTheDocument()
|
||||
await waitFor(() => {
|
||||
expect(mockWebAppLoginStatus).not.toHaveBeenCalled()
|
||||
})
|
||||
expect(mockFetchAccessToken).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should resume the auth flow after the share code becomes available', async () => {
|
||||
const { rerender } = render(
|
||||
<Splash>
|
||||
<div>share child</div>
|
||||
</Splash>,
|
||||
)
|
||||
|
||||
mockShareCode = 'share-code'
|
||||
rerender(
|
||||
<Splash>
|
||||
<div>share child</div>
|
||||
</Splash>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockWebAppLoginStatus).toHaveBeenCalledWith('share-code', undefined)
|
||||
})
|
||||
expect(mockReplace).toHaveBeenCalledWith('/chat/test-share-code')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -4,7 +4,6 @@ import * as React from 'react'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
@ -18,9 +17,9 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
const updateAppParams = useWebAppStore(s => s.updateAppParams)
|
||||
const updateWebAppMeta = useWebAppStore(s => s.updateWebAppMeta)
|
||||
const updateUserCanAccessApp = useWebAppStore(s => s.updateUserCanAccessApp)
|
||||
const { isFetching: isFetchingAppParams, data: appParams, error: appParamsError } = useGetWebAppParams()
|
||||
const { isFetching: isFetchingAppInfo, data: appInfo, error: appInfoError } = useGetWebAppInfo()
|
||||
const { isFetching: isFetchingAppMeta, data: appMeta, error: appMetaError } = useGetWebAppMeta()
|
||||
const { data: appParams, error: appParamsError } = useGetWebAppParams()
|
||||
const { data: appInfo, error: appInfoError } = useGetWebAppInfo()
|
||||
const { data: appMeta, error: appMetaError } = useGetWebAppMeta()
|
||||
const { data: userCanAccessApp, error: useCanAccessAppError } = useGetUserCanAccessApp({ appId: appInfo?.app_id, isInstalledApp: false })
|
||||
|
||||
useEffect(() => {
|
||||
@ -81,17 +80,11 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
return (
|
||||
<div className="flex h-full flex-col items-center justify-center gap-y-2">
|
||||
<AppUnavailable className="h-auto w-auto" code={403} unknownReason="no permission." />
|
||||
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{t('userProfile.logout', { ns: 'common' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (isFetchingAppInfo || isFetchingAppParams || isFetchingAppMeta) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Loading />
|
||||
<span className="cursor-pointer system-sm-regular text-text-tertiary" onClick={backToHome}>{t('userProfile.logout', { ns: 'common' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return <>{children}</>
|
||||
}
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
'use client'
|
||||
import type { FC, PropsWithChildren } from 'react'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { fetchAccessToken } from '@/service/share'
|
||||
@ -32,11 +31,8 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
||||
const url = getSigninUrl()
|
||||
router.replace(url)
|
||||
}, [getSigninUrl, router, shareCode])
|
||||
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
useEffect(() => {
|
||||
if (message) {
|
||||
setIsLoading(false)
|
||||
return
|
||||
}
|
||||
|
||||
@ -46,12 +42,6 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
||||
const redirectOrFinish = () => {
|
||||
if (redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
else
|
||||
setIsLoading(false)
|
||||
}
|
||||
|
||||
const proceedToAuth = () => {
|
||||
setIsLoading(false)
|
||||
}
|
||||
|
||||
(async () => {
|
||||
@ -60,9 +50,6 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
||||
if (userLoggedIn && appLoggedIn) {
|
||||
redirectOrFinish()
|
||||
}
|
||||
else if (!userLoggedIn && !appLoggedIn) {
|
||||
proceedToAuth()
|
||||
}
|
||||
else if (!userLoggedIn && appLoggedIn) {
|
||||
redirectOrFinish()
|
||||
}
|
||||
@ -77,7 +64,6 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
||||
}
|
||||
catch {
|
||||
await webAppLogout(shareCode!)
|
||||
proceedToAuth()
|
||||
}
|
||||
}
|
||||
})()
|
||||
@ -95,18 +81,11 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
||||
return (
|
||||
<div className="flex h-full flex-col items-center justify-center gap-y-4">
|
||||
<AppUnavailable className="h-auto w-auto" code={code || t('common.appUnavailable', { ns: 'share' })} unknownReason={message} />
|
||||
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}</span>
|
||||
<span className="cursor-pointer system-sm-regular text-text-tertiary" onClick={backToHome}>{code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return <>{children}</>
|
||||
}
|
||||
|
||||
|
||||
75
web/app/(shareLayout)/webapp-signin/__tests__/page.spec.tsx
Normal file
75
web/app/(shareLayout)/webapp-signin/__tests__/page.spec.tsx
Normal file
@ -0,0 +1,75 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import WebSSOForm from '../page'
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
let mockRedirectUrl = '/share/test-share-code'
|
||||
let mockWebAppAccessMode: AccessMode | null = null
|
||||
let mockSystemFeaturesEnabled = true
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
}),
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => key === 'redirect_url' ? mockRedirectUrl : null,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: { systemFeatures: { webapp_auth: { enabled: boolean } } }) => unknown) =>
|
||||
selector({
|
||||
systemFeatures: {
|
||||
webapp_auth: {
|
||||
enabled: mockSystemFeaturesEnabled,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/web-app-context', () => ({
|
||||
useWebAppStore: (selector: (state: { webAppAccessMode: AccessMode | null, shareCode: string | null }) => unknown) =>
|
||||
selector({
|
||||
webAppAccessMode: mockWebAppAccessMode,
|
||||
shareCode: 'test-share-code',
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/webapp-auth', () => ({
|
||||
webAppLogout: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('../normalForm', () => ({
|
||||
default: () => <div data-testid="normal-form" />,
|
||||
}))
|
||||
|
||||
vi.mock('../components/external-member-sso-auth', () => ({
|
||||
default: () => <div data-testid="external-member-sso-auth" />,
|
||||
}))
|
||||
|
||||
describe('WebSSOForm', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockRedirectUrl = '/share/test-share-code'
|
||||
mockWebAppAccessMode = null
|
||||
mockSystemFeaturesEnabled = true
|
||||
})
|
||||
|
||||
describe('Access Mode Resolution', () => {
|
||||
it('should avoid rendering auth variants before the access mode query resolves', () => {
|
||||
render(<WebSSOForm />)
|
||||
|
||||
expect(screen.queryByTestId('normal-form')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('external-member-sso-auth')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('share.login.backToHome')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the normal form for organization-backed access modes', () => {
|
||||
mockWebAppAccessMode = AccessMode.ORGANIZATION
|
||||
|
||||
render(<WebSSOForm />)
|
||||
|
||||
expect(screen.getByTestId('normal-form')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -63,7 +63,7 @@ const WebSSOForm: FC = () => {
|
||||
return (
|
||||
<div className="flex h-full flex-col items-center justify-center gap-y-4">
|
||||
<AppUnavailable className="h-auto w-auto" isUnknownReason={true} />
|
||||
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{t('login.backToHome', { ns: 'share' })}</span>
|
||||
<span className="cursor-pointer system-sm-regular text-text-tertiary" onClick={backToHome}>{t('login.backToHome', { ns: 'share' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
'use client'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
|
||||
import Header from '@/app/signin/_header'
|
||||
import { AppContextProvider } from '@/context/app-context-provider'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
@ -11,16 +9,8 @@ import { cn } from '@/utils/classnames'
|
||||
export default function SignInLayout({ children }: any) {
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
useDocumentTitle('')
|
||||
const { isLoading, data: loginData } = useIsLogin()
|
||||
const { data: loginData } = useIsLogin()
|
||||
const isLoggedIn = loginData?.logged_in
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex min-h-screen w-full justify-center bg-background-default-burn">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<div className={cn('flex min-h-screen w-full justify-center bg-background-default-burn p-6')}>
|
||||
|
||||
@ -12,7 +12,6 @@ import { useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
@ -71,6 +70,8 @@ export default function OAuthAuthorize() {
|
||||
const isLoggedIn = loginData?.logged_in
|
||||
const isLoading = isOAuthLoading || isIsLoginLoading
|
||||
const onLoginSwitchClick = () => {
|
||||
if (isLoading)
|
||||
return
|
||||
try {
|
||||
const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`)
|
||||
setPostLoginRedirect(returnUrl)
|
||||
@ -106,14 +107,6 @@ export default function OAuthAuthorize() {
|
||||
}
|
||||
}, [client_id, redirect_uri, isError])
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="bg-background-default-subtle">
|
||||
<Loading type="app" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="bg-background-default-subtle">
|
||||
{authAppInfo?.app_icon && (
|
||||
@ -122,13 +115,13 @@ export default function OAuthAuthorize() {
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className={`mb-4 mt-5 flex flex-col gap-2 ${isLoggedIn ? 'pb-2' : ''}`}>
|
||||
<div className={`mt-5 mb-4 flex flex-col gap-2 ${isLoggedIn ? 'pb-2' : ''}`}>
|
||||
<div className="title-4xl-semi-bold">
|
||||
{isLoggedIn && <div className="text-text-primary">{t('connect', { ns: 'oauth' })}</div>}
|
||||
<div className="text-(--color-saas-dify-blue-inverted)">{authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })}</div>
|
||||
{!isLoggedIn && <div className="text-text-primary">{t('tips.notLoggedIn', { ns: 'oauth' })}</div>}
|
||||
</div>
|
||||
<div className="text-text-secondary body-md-regular">{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })} ${t('tips.loggedIn', { ns: 'oauth' })}` : t('tips.needLogin', { ns: 'oauth' })}</div>
|
||||
<div className="body-md-regular text-text-secondary">{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })} ${t('tips.loggedIn', { ns: 'oauth' })}` : t('tips.needLogin', { ns: 'oauth' })}</div>
|
||||
</div>
|
||||
|
||||
{isLoggedIn && userProfile && (
|
||||
@ -137,7 +130,7 @@ export default function OAuthAuthorize() {
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="lg" />
|
||||
<div>
|
||||
<div className="system-md-semi-bold text-text-secondary">{userProfile.name}</div>
|
||||
<div className="text-text-tertiary system-xs-regular">{userProfile.email}</div>
|
||||
<div className="system-xs-regular text-text-tertiary">{userProfile.email}</div>
|
||||
</div>
|
||||
</div>
|
||||
<Button variant="tertiary" size="small" onClick={onLoginSwitchClick}>{t('switchAccount', { ns: 'oauth' })}</Button>
|
||||
@ -149,7 +142,7 @@ export default function OAuthAuthorize() {
|
||||
{authAppInfo!.scope.split(/\s+/).filter(Boolean).map((scope: string) => {
|
||||
const Icon = SCOPE_INFO_MAP[scope]
|
||||
return (
|
||||
<div key={scope} className="flex items-center gap-2 text-text-secondary body-sm-medium">
|
||||
<div key={scope} className="flex items-center gap-2 body-sm-medium text-text-secondary">
|
||||
{Icon ? <Icon.icon className="h-4 w-4" /> : <RiAccountCircleLine className="h-4 w-4" />}
|
||||
{Icon.label}
|
||||
</div>
|
||||
@ -182,7 +175,7 @@ export default function OAuthAuthorize() {
|
||||
</defs>
|
||||
</svg>
|
||||
</div>
|
||||
<div className="mt-3 text-text-tertiary system-xs-regular">{t('tips.common', { ns: 'oauth' })}</div>
|
||||
<div className="mt-3 system-xs-regular text-text-tertiary">{t('tips.common', { ns: 'oauth' })}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import Cookies from 'js-cookie'
|
||||
import { parseAsBoolean, useQueryState } from 'nuqs'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import {
|
||||
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
|
||||
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
|
||||
@ -25,7 +25,6 @@ export const AppInitializer = ({
|
||||
const searchParams = useSearchParams()
|
||||
// Tokens are now stored in cookies, no need to check localStorage
|
||||
const pathname = usePathname()
|
||||
const [init, setInit] = useState(false)
|
||||
const [oauthNewUser] = useQueryState(
|
||||
'oauth_new_user',
|
||||
parseAsBoolean.withOptions({ history: 'replace' }),
|
||||
@ -87,10 +86,7 @@ export const AppInitializer = ({
|
||||
const redirectUrl = resolvePostLoginRedirect()
|
||||
if (redirectUrl) {
|
||||
location.replace(redirectUrl)
|
||||
return
|
||||
}
|
||||
|
||||
setInit(true)
|
||||
}
|
||||
catch {
|
||||
router.replace('/signin')
|
||||
@ -98,5 +94,5 @@ export const AppInitializer = ({
|
||||
})()
|
||||
}, [isSetupFinished, router, pathname, searchParams, oauthNewUser])
|
||||
|
||||
return init ? children : null
|
||||
return children
|
||||
}
|
||||
|
||||
@ -2,8 +2,6 @@ import { render } from '@testing-library/react'
|
||||
import PartnerStackCookieRecorder from '../cookie-recorder'
|
||||
|
||||
let isCloudEdition = true
|
||||
let psPartnerKey: string | undefined
|
||||
let psClickId: string | undefined
|
||||
|
||||
const saveOrUpdate = vi.fn()
|
||||
|
||||
@ -15,8 +13,6 @@ vi.mock('@/config', () => ({
|
||||
|
||||
vi.mock('../use-ps-info', () => ({
|
||||
default: () => ({
|
||||
psPartnerKey,
|
||||
psClickId,
|
||||
saveOrUpdate,
|
||||
}),
|
||||
}))
|
||||
@ -25,8 +21,6 @@ describe('PartnerStackCookieRecorder', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
isCloudEdition = true
|
||||
psPartnerKey = undefined
|
||||
psClickId = undefined
|
||||
})
|
||||
|
||||
it('should call saveOrUpdate once on mount when running in cloud edition', () => {
|
||||
@ -48,16 +42,4 @@ describe('PartnerStackCookieRecorder', () => {
|
||||
|
||||
expect(container.innerHTML).toBe('')
|
||||
})
|
||||
|
||||
it('should call saveOrUpdate again when partner stack query changes', () => {
|
||||
const { rerender } = render(<PartnerStackCookieRecorder />)
|
||||
|
||||
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
|
||||
|
||||
psPartnerKey = 'updated-partner'
|
||||
psClickId = 'updated-click'
|
||||
rerender(<PartnerStackCookieRecorder />)
|
||||
|
||||
expect(saveOrUpdate).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
@ -5,13 +5,13 @@ import { IS_CLOUD_EDITION } from '@/config'
|
||||
import usePSInfo from './use-ps-info'
|
||||
|
||||
const PartnerStackCookieRecorder = () => {
|
||||
const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo()
|
||||
const { saveOrUpdate } = usePSInfo()
|
||||
|
||||
useEffect(() => {
|
||||
if (!IS_CLOUD_EDITION)
|
||||
return
|
||||
saveOrUpdate()
|
||||
}, [psPartnerKey, psClickId, saveOrUpdate])
|
||||
}, [])
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ import { IS_CLOUD_EDITION } from '@/config'
|
||||
import usePSInfo from './use-ps-info'
|
||||
|
||||
const PartnerStack: FC = () => {
|
||||
const { psPartnerKey, psClickId, saveOrUpdate, bind } = usePSInfo()
|
||||
const { saveOrUpdate, bind } = usePSInfo()
|
||||
useEffect(() => {
|
||||
if (!IS_CLOUD_EDITION)
|
||||
return
|
||||
@ -14,7 +14,7 @@ const PartnerStack: FC = () => {
|
||||
saveOrUpdate()
|
||||
// bind PartnerStack info after user logged in
|
||||
bind()
|
||||
}, [psPartnerKey, psClickId, saveOrUpdate, bind])
|
||||
}, [])
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@ -24,11 +24,9 @@ const usePSInfo = () => {
|
||||
}] = useBoolean(false)
|
||||
const { mutateAsync } = useBindPartnerStackInfo()
|
||||
// Save to top domain. cloud.dify.ai => .dify.ai
|
||||
const domain = globalThis.location?.hostname.replace('cloud', '')
|
||||
const domain = globalThis.location?.hostname?.replace('cloud', '')
|
||||
|
||||
const saveOrUpdate = useCallback(() => {
|
||||
if (hasBind)
|
||||
return
|
||||
if (!psPartnerKey || !psClickId)
|
||||
return
|
||||
if (!isPSChanged)
|
||||
@ -39,23 +37,11 @@ const usePSInfo = () => {
|
||||
}), {
|
||||
expires: PARTNER_STACK_CONFIG.saveCookieDays,
|
||||
path: '/',
|
||||
domain,
|
||||
...(domain ? { domain } : {}),
|
||||
})
|
||||
}, [psPartnerKey, psClickId, isPSChanged, domain, hasBind])
|
||||
}, [psPartnerKey, psClickId, isPSChanged, domain])
|
||||
|
||||
const bind = useCallback(async () => {
|
||||
// for debug
|
||||
if (!hasBind)
|
||||
fetch("https://cloud.dify.dev/console/api/billing/debug/data", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
type: "bind",
|
||||
data: psPartnerKey ? JSON.stringify({ psPartnerKey, psClickId }) : "",
|
||||
}),
|
||||
})
|
||||
if (psPartnerKey && psClickId && !hasBind) {
|
||||
let shouldRemoveCookie = false
|
||||
try {
|
||||
@ -69,8 +55,12 @@ const usePSInfo = () => {
|
||||
if ((error as { status: number })?.status === 400)
|
||||
shouldRemoveCookie = true
|
||||
}
|
||||
if (shouldRemoveCookie)
|
||||
Cookies.remove(PARTNER_STACK_CONFIG.cookieName, { path: '/', domain })
|
||||
if (shouldRemoveCookie) {
|
||||
Cookies.remove(PARTNER_STACK_CONFIG.cookieName, {
|
||||
path: '/',
|
||||
...(domain ? { domain } : {}),
|
||||
})
|
||||
}
|
||||
setBind()
|
||||
}
|
||||
}, [psPartnerKey, psClickId, hasBind, domain, setBind, mutateAsync])
|
||||
|
||||
@ -10,6 +10,8 @@ type HeaderWrapperProps = {
|
||||
children: React.ReactNode
|
||||
}
|
||||
|
||||
const getWorkflowCanvasMaximize = () => globalThis.localStorage?.getItem('workflow-canvas-maximize') === 'true'
|
||||
|
||||
const HeaderWrapper = ({
|
||||
children,
|
||||
}: HeaderWrapperProps) => {
|
||||
@ -18,8 +20,7 @@ const HeaderWrapper = ({
|
||||
// Check if the current path is a workflow canvas & fullscreen
|
||||
const inWorkflowCanvas = pathname.endsWith('/workflow')
|
||||
const isPipelineCanvas = pathname.endsWith('/pipeline')
|
||||
const workflowCanvasMaximize = localStorage.getItem('workflow-canvas-maximize') === 'true'
|
||||
const [hideHeader, setHideHeader] = useState(workflowCanvasMaximize)
|
||||
const [hideHeader, setHideHeader] = useState(getWorkflowCanvasMaximize)
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
|
||||
eventEmitter?.useSubscription((v: any) => {
|
||||
|
||||
@ -3,16 +3,19 @@ import { X } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { NOTICE_I18N } from '@/i18n-config/language'
|
||||
|
||||
const getShowNotice = () => globalThis.localStorage?.getItem('hide-maintenance-notice') !== '1'
|
||||
|
||||
const MaintenanceNotice = () => {
|
||||
const locale = useLanguage()
|
||||
|
||||
const [showNotice, setShowNotice] = useState(() => localStorage.getItem('hide-maintenance-notice') !== '1')
|
||||
const [showNotice, setShowNotice] = useState(getShowNotice)
|
||||
|
||||
const handleJumpNotice = () => {
|
||||
window.open(NOTICE_I18N.href, '_blank')
|
||||
}
|
||||
|
||||
const handleCloseNotice = () => {
|
||||
localStorage.setItem('hide-maintenance-notice', '1')
|
||||
globalThis.localStorage?.setItem('hide-maintenance-notice', '1')
|
||||
setShowNotice(false)
|
||||
}
|
||||
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import type { SiteInfo } from '@/models/share'
|
||||
import { act, cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import MenuDropdown from '../menu-dropdown'
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
const mockPathname = '/test-path'
|
||||
let mockWebAppAccessMode: AccessMode | null = AccessMode.SPECIFIC_GROUPS_MEMBERS
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
@ -16,7 +18,7 @@ const mockShareCode = 'test-share-code'
|
||||
vi.mock('@/context/web-app-context', () => ({
|
||||
useWebAppStore: (selector: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
webAppAccessMode: 'code',
|
||||
webAppAccessMode: mockWebAppAccessMode,
|
||||
shareCode: mockShareCode,
|
||||
}
|
||||
return selector(state)
|
||||
@ -41,6 +43,7 @@ describe('MenuDropdown', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWebAppAccessMode = AccessMode.SPECIFIC_GROUPS_MEMBERS
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
@ -151,6 +154,19 @@ describe('MenuDropdown', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should hide logout option when access mode is unknown', async () => {
|
||||
mockWebAppAccessMode = null
|
||||
|
||||
render(<MenuDropdown data={baseSiteInfo} hideLogout={false} />)
|
||||
|
||||
const triggerButton = screen.getByRole('button')
|
||||
fireEvent.click(triggerButton)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('common.userProfile.logout')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call webAppLogout and redirect when logout is clicked', async () => {
|
||||
render(<MenuDropdown data={baseSiteInfo} hideLogout={false} />)
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { AppSourceType } from '@/service/share'
|
||||
import { useTextGenerationAppState } from '../use-text-generation-app-state'
|
||||
|
||||
@ -124,13 +125,13 @@ const defaultAppParams = {
|
||||
type MockWebAppState = {
|
||||
appInfo: MockAppInfo | null
|
||||
appParams: typeof defaultAppParams | null
|
||||
webAppAccessMode: string
|
||||
webAppAccessMode: AccessMode | null
|
||||
}
|
||||
|
||||
const mockWebAppState: MockWebAppState = {
|
||||
appInfo: defaultAppInfo,
|
||||
appParams: defaultAppParams,
|
||||
webAppAccessMode: 'public',
|
||||
webAppAccessMode: AccessMode.PUBLIC,
|
||||
}
|
||||
|
||||
const resetMockWebAppState = () => {
|
||||
@ -160,7 +161,7 @@ const resetMockWebAppState = () => {
|
||||
image_file_size_limit: 10,
|
||||
},
|
||||
}
|
||||
mockWebAppState.webAppAccessMode = 'public'
|
||||
mockWebAppState.webAppAccessMode = AccessMode.PUBLIC
|
||||
}
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
|
||||
@ -58,6 +58,7 @@ const MenuDropdown: FC<Props> = ({
|
||||
}, [router, pathname, webAppLogout, shareCode])
|
||||
|
||||
const [show, setShow] = useState(false)
|
||||
const showLogout = !hideLogout && webAppAccessMode !== null && webAppAccessMode !== AccessMode.EXTERNAL_MEMBERS && webAppAccessMode !== AccessMode.PUBLIC
|
||||
|
||||
useEffect(() => {
|
||||
if (forceClose)
|
||||
@ -85,7 +86,7 @@ const MenuDropdown: FC<Props> = ({
|
||||
<PortalToFollowElemContent className="z-50">
|
||||
<div className="w-[224px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-xs">
|
||||
<div className="p-1">
|
||||
<div className={cn('system-md-regular flex cursor-pointer items-center rounded-lg py-1.5 pl-3 pr-2 text-text-secondary')}>
|
||||
<div className={cn('flex cursor-pointer items-center rounded-lg py-1.5 pl-3 pr-2 text-text-secondary system-md-regular')}>
|
||||
<div className="grow">{t('theme.theme', { ns: 'common' })}</div>
|
||||
<ThemeSwitcher />
|
||||
</div>
|
||||
@ -93,7 +94,7 @@ const MenuDropdown: FC<Props> = ({
|
||||
<Divider type="horizontal" className="my-0" />
|
||||
<div className="p-1">
|
||||
{data?.privacy_policy && (
|
||||
<a href={data.privacy_policy} target="_blank" className="system-md-regular flex cursor-pointer items-center rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover">
|
||||
<a href={data.privacy_policy} target="_blank" rel="noopener noreferrer" className="flex cursor-pointer items-center rounded-lg px-3 py-1.5 text-text-secondary system-md-regular hover:bg-state-base-hover">
|
||||
<span className="grow">{t('chat.privacyPolicyMiddle', { ns: 'share' })}</span>
|
||||
</a>
|
||||
)}
|
||||
@ -102,16 +103,16 @@ const MenuDropdown: FC<Props> = ({
|
||||
handleTrigger()
|
||||
setShow(true)
|
||||
}}
|
||||
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover"
|
||||
className="cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary system-md-regular hover:bg-state-base-hover"
|
||||
>
|
||||
{t('userProfile.about', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
{!(hideLogout || webAppAccessMode === AccessMode.EXTERNAL_MEMBERS || webAppAccessMode === AccessMode.PUBLIC) && (
|
||||
{showLogout && (
|
||||
<div className="p-1">
|
||||
<div
|
||||
onClick={handleLogout}
|
||||
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover"
|
||||
className="cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary system-md-regular hover:bg-state-base-hover"
|
||||
>
|
||||
{t('userProfile.logout', { ns: 'common' })}
|
||||
</div>
|
||||
|
||||
@ -18,7 +18,7 @@ import RunBatch from './run-batch'
|
||||
import RunOnce from './run-once'
|
||||
|
||||
type TextGenerationSidebarProps = {
|
||||
accessMode: AccessMode
|
||||
accessMode: AccessMode | null
|
||||
allTasksRun: boolean
|
||||
currentTab: string
|
||||
customConfig: TextGenerationCustomConfig | null
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
'use client'
|
||||
import type { FC, PropsWithChildren } from 'react'
|
||||
import * as React from 'react'
|
||||
import { useIsLogin } from '@/service/use-common'
|
||||
import Loading from './base/loading'
|
||||
|
||||
const Splash: FC<PropsWithChildren> = () => {
|
||||
// would auto redirect to signin page if not logged in
|
||||
const { isLoading, data: loginData } = useIsLogin()
|
||||
const isLoggedIn = loginData?.logged_in
|
||||
|
||||
if (isLoading || !isLoggedIn) {
|
||||
return (
|
||||
<div className="fixed inset-0 z-9999999 flex h-full items-center justify-center bg-background-body">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return null
|
||||
}
|
||||
export default React.memo(Splash)
|
||||
@ -46,7 +46,7 @@ const EducationApplyAge = () => {
|
||||
setShowModal(undefined)
|
||||
onPlanInfoChanged()
|
||||
updateEducationStatus()
|
||||
localStorage.removeItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
globalThis.localStorage?.removeItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
router.replace('/')
|
||||
}
|
||||
|
||||
|
||||
@ -133,7 +133,7 @@ const useEducationReverifyNotice = ({
|
||||
export const useEducationInit = () => {
|
||||
const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal)
|
||||
const setShowEducationExpireNoticeModal = useModalContextSelector(s => s.setShowEducationExpireNoticeModal)
|
||||
const educationVerifying = localStorage.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
const educationVerifying = globalThis.localStorage?.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
const searchParams = useSearchParams()
|
||||
const educationVerifyAction = searchParams.get('action')
|
||||
|
||||
@ -156,7 +156,7 @@ export const useEducationInit = () => {
|
||||
setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.BILLING })
|
||||
|
||||
if (educationVerifyAction === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION)
|
||||
localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes')
|
||||
globalThis.localStorage?.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes')
|
||||
}
|
||||
if (educationVerifyAction === EDUCATION_RE_VERIFY_ACTION)
|
||||
handleVerify()
|
||||
|
||||
@ -1,18 +1,5 @@
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Link from '@/next/link'
|
||||
import { redirect } from '@/next/navigation'
|
||||
|
||||
const Home = async () => {
|
||||
return (
|
||||
<div className="flex min-h-screen flex-col justify-center py-12 sm:px-6 lg:px-8">
|
||||
|
||||
<div className="sm:mx-auto sm:w-full sm:max-w-md">
|
||||
<Loading type="area" />
|
||||
<div className="mt-10 text-center">
|
||||
<Link href="/apps">🚀</Link>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
export default async function Home() {
|
||||
redirect('/apps')
|
||||
}
|
||||
|
||||
export default Home
|
||||
|
||||
111
web/app/signin/invite-settings/__tests__/page.spec.tsx
Normal file
111
web/app/signin/invite-settings/__tests__/page.spec.tsx
Normal file
@ -0,0 +1,111 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import InviteSettingsPage from '../page'
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
const mockRefetch = vi.fn()
|
||||
const mockActivateMember = vi.fn()
|
||||
const mockSetLocaleOnClient = vi.fn()
|
||||
const mockResolvePostLoginRedirect = vi.fn()
|
||||
|
||||
let mockInviteToken = 'invite-token'
|
||||
let mockCheckRes: {
|
||||
is_valid: boolean
|
||||
data: {
|
||||
workspace_name: string
|
||||
email: string
|
||||
workspace_id: string
|
||||
}
|
||||
} | undefined
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
replace: mockReplace,
|
||||
}),
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => key === 'invite_token' ? mockInviteToken : null,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: { systemFeatures: { branding: { enabled: boolean } } }) => unknown) =>
|
||||
selector({
|
||||
systemFeatures: {
|
||||
branding: {
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useInvitationCheck: () => ({
|
||||
data: mockCheckRes,
|
||||
refetch: mockRefetch,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
activateMember: (...args: unknown[]) => mockActivateMember(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n-config', () => ({
|
||||
setLocaleOnClient: (...args: unknown[]) => mockSetLocaleOnClient(...args),
|
||||
}))
|
||||
|
||||
vi.mock('../../utils/post-login-redirect', () => ({
|
||||
resolvePostLoginRedirect: () => mockResolvePostLoginRedirect(),
|
||||
}))
|
||||
|
||||
describe('InviteSettingsPage', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockInviteToken = 'invite-token'
|
||||
mockCheckRes = undefined
|
||||
mockActivateMember.mockResolvedValue({ result: 'success' })
|
||||
mockSetLocaleOnClient.mockResolvedValue(undefined)
|
||||
mockResolvePostLoginRedirect.mockReturnValue('/apps')
|
||||
})
|
||||
|
||||
describe('Activation Gating', () => {
|
||||
it('should not activate when invitation validation is still pending and Enter is pressed', () => {
|
||||
render(<InviteSettingsPage />)
|
||||
|
||||
const nameInput = screen.getByLabelText('login.name')
|
||||
fireEvent.change(nameInput, { target: { value: 'Alice' } })
|
||||
fireEvent.keyDown(nameInput, { key: 'Enter', code: 'Enter', charCode: 13 })
|
||||
|
||||
expect(mockActivateMember).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should activate when invitation validation has succeeded', async () => {
|
||||
mockCheckRes = {
|
||||
is_valid: true,
|
||||
data: {
|
||||
workspace_name: 'Demo Workspace',
|
||||
email: 'alice@example.com',
|
||||
workspace_id: 'workspace-1',
|
||||
},
|
||||
}
|
||||
|
||||
render(<InviteSettingsPage />)
|
||||
|
||||
const nameInput = screen.getByLabelText('login.name')
|
||||
fireEvent.change(nameInput, { target: { value: 'Alice' } })
|
||||
fireEvent.keyDown(nameInput, { key: 'Enter', code: 'Enter', charCode: 13 })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockActivateMember).toHaveBeenCalledWith({
|
||||
url: '/activate',
|
||||
body: {
|
||||
token: 'invite-token',
|
||||
name: 'Alice',
|
||||
interface_language: 'en-US',
|
||||
timezone: expect.any(String),
|
||||
},
|
||||
})
|
||||
})
|
||||
expect(mockSetLocaleOnClient).toHaveBeenCalledWith('en-US', false)
|
||||
expect(mockReplace).toHaveBeenCalledWith('/apps')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -6,7 +6,6 @@ import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { SimpleSelect } from '@/app/components/base/select'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { LICENSE_LINK } from '@/constants/link'
|
||||
@ -37,9 +36,12 @@ export default function InviteSettingsPage() {
|
||||
},
|
||||
}
|
||||
const { data: checkRes, refetch: recheck } = useInvitationCheck(checkParams.params, !!token)
|
||||
const canActivate = checkRes?.is_valid === true
|
||||
|
||||
const handleActivate = useCallback(async () => {
|
||||
try {
|
||||
if (!canActivate)
|
||||
return
|
||||
if (!name) {
|
||||
toast.error(t('enterYourName', { ns: 'login' }))
|
||||
return
|
||||
@ -63,11 +65,9 @@ export default function InviteSettingsPage() {
|
||||
catch {
|
||||
recheck()
|
||||
}
|
||||
}, [language, name, recheck, timezone, token, router, t])
|
||||
}, [canActivate, language, name, recheck, timezone, token, router, t])
|
||||
|
||||
if (!checkRes)
|
||||
return <Loading />
|
||||
if (!checkRes.is_valid) {
|
||||
if (checkRes?.is_valid === false) {
|
||||
return (
|
||||
<div className="flex flex-col md:w-[400px]">
|
||||
<div className="mx-auto w-full">
|
||||
@ -107,7 +107,8 @@ export default function InviteSettingsPage() {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
handleActivate()
|
||||
if (canActivate)
|
||||
handleActivate()
|
||||
}
|
||||
}}
|
||||
/>
|
||||
@ -147,8 +148,9 @@ export default function InviteSettingsPage() {
|
||||
variant="primary"
|
||||
className="w-full"
|
||||
onClick={handleActivate}
|
||||
disabled={!canActivate}
|
||||
>
|
||||
{`${t('join', { ns: 'login' })} ${checkRes?.data?.workspace_name}`}
|
||||
{`${t('join', { ns: 'login' })} ${checkRes?.data?.workspace_name ?? ''}`}
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
74
web/context/__tests__/global-public-context.spec.tsx
Normal file
74
web/context/__tests__/global-public-context.spec.tsx
Normal file
@ -0,0 +1,74 @@
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import GlobalPublicStoreProvider, { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { defaultSystemFeatures } from '@/types/feature'
|
||||
|
||||
const mockSystemFeatures = vi.fn()
|
||||
const mockFetchSetupStatusWithCache = vi.fn()
|
||||
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleClient: {
|
||||
systemFeatures: () => mockSystemFeatures(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/setup-status', () => ({
|
||||
fetchSetupStatusWithCache: () => mockFetchSetupStatusWithCache(),
|
||||
}))
|
||||
|
||||
const createQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const renderProvider = () => {
|
||||
const queryClient = createQueryClient()
|
||||
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<GlobalPublicStoreProvider>
|
||||
<div>provider child</div>
|
||||
</GlobalPublicStoreProvider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('GlobalPublicStoreProvider', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
useGlobalPublicStore.setState({ systemFeatures: defaultSystemFeatures })
|
||||
mockFetchSetupStatusWithCache.mockResolvedValue({ setup_status: 'finished' })
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render children when system features are still loading', async () => {
|
||||
mockSystemFeatures.mockReturnValue(new Promise(() => {}))
|
||||
|
||||
renderProvider()
|
||||
|
||||
expect(screen.getByText('provider child')).toBeInTheDocument()
|
||||
await waitFor(() => {
|
||||
expect(mockSystemFeatures).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('State Updates', () => {
|
||||
it('should update the public store when system features query succeeds', async () => {
|
||||
mockSystemFeatures.mockResolvedValue({
|
||||
...defaultSystemFeatures,
|
||||
enable_marketplace: true,
|
||||
})
|
||||
|
||||
renderProvider()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(useGlobalPublicStore.getState().systemFeatures.enable_marketplace).toBe(true)
|
||||
})
|
||||
expect(mockFetchSetupStatusWithCache).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
108
web/context/__tests__/web-app-context.spec.tsx
Normal file
108
web/context/__tests__/web-app-context.spec.tsx
Normal file
@ -0,0 +1,108 @@
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
|
||||
let mockPathname = '/share/test-share-code'
|
||||
let mockRedirectUrl: string | null = null
|
||||
let mockAccessModeResult: { accessMode: AccessMode } | undefined
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
usePathname: () => mockPathname,
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => key === 'redirect_url' ? mockRedirectUrl : null,
|
||||
toString: () => {
|
||||
const params = new URLSearchParams()
|
||||
if (mockRedirectUrl)
|
||||
params.set('redirect_url', mockRedirectUrl)
|
||||
return params.toString()
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/chat/utils', () => ({
|
||||
getProcessedSystemVariablesFromUrlParams: vi.fn().mockResolvedValue({}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-share', () => ({
|
||||
useGetWebAppAccessModeByCode: vi.fn(() => ({
|
||||
data: mockAccessModeResult,
|
||||
})),
|
||||
}))
|
||||
|
||||
const StoreSnapshot = () => {
|
||||
const shareCode = useWebAppStore(s => s.shareCode)
|
||||
const accessMode = useWebAppStore(s => s.webAppAccessMode)
|
||||
|
||||
return (
|
||||
<div>
|
||||
<span data-testid="share-code">{shareCode ?? 'none'}</span>
|
||||
<span data-testid="access-mode">{accessMode ?? 'unknown'}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
describe('WebAppStoreProvider', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockPathname = '/share/test-share-code'
|
||||
mockRedirectUrl = null
|
||||
mockAccessModeResult = undefined
|
||||
useWebAppStore.setState({
|
||||
shareCode: null,
|
||||
appInfo: null,
|
||||
appParams: null,
|
||||
webAppAccessMode: null,
|
||||
appMeta: null,
|
||||
userCanAccessApp: false,
|
||||
embeddedUserId: null,
|
||||
embeddedConversationId: null,
|
||||
})
|
||||
})
|
||||
|
||||
describe('Access Mode State', () => {
|
||||
it('should keep the access mode unknown until the query resolves', async () => {
|
||||
render(
|
||||
<WebAppStoreProvider>
|
||||
<StoreSnapshot />
|
||||
</WebAppStoreProvider>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('share-code')).toHaveTextContent('test-share-code')
|
||||
})
|
||||
expect(screen.getByTestId('access-mode')).toHaveTextContent('unknown')
|
||||
})
|
||||
|
||||
it('should reset the access mode when the share code changes before the next result arrives', async () => {
|
||||
const { rerender } = render(
|
||||
<WebAppStoreProvider>
|
||||
<StoreSnapshot />
|
||||
</WebAppStoreProvider>,
|
||||
)
|
||||
|
||||
mockAccessModeResult = { accessMode: AccessMode.PUBLIC }
|
||||
rerender(
|
||||
<WebAppStoreProvider>
|
||||
<StoreSnapshot />
|
||||
</WebAppStoreProvider>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('access-mode')).toHaveTextContent(AccessMode.PUBLIC)
|
||||
})
|
||||
|
||||
mockPathname = '/share/next-share-code'
|
||||
mockAccessModeResult = undefined
|
||||
rerender(
|
||||
<WebAppStoreProvider>
|
||||
<StoreSnapshot />
|
||||
</WebAppStoreProvider>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('share-code')).toHaveTextContent('next-share-code')
|
||||
})
|
||||
expect(screen.getByTestId('access-mode')).toHaveTextContent('unknown')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -3,7 +3,6 @@ import type { FC, PropsWithChildren } from 'react'
|
||||
import type { SystemFeatures } from '@/types/feature'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { create } from 'zustand'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import { defaultSystemFeatures } from '@/types/feature'
|
||||
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
|
||||
@ -53,13 +52,11 @@ const GlobalPublicStoreProvider: FC<PropsWithChildren> = ({
|
||||
}) => {
|
||||
// Fetch systemFeatures and setupStatus in parallel to reduce waterfall.
|
||||
// setupStatus is prefetched here and cached in localStorage for AppInitializer.
|
||||
const { isPending } = useSystemFeaturesQuery()
|
||||
useSystemFeaturesQuery()
|
||||
|
||||
// Prefetch setupStatus for AppInitializer (result not needed here)
|
||||
useSetupStatusQuery()
|
||||
|
||||
if (isPending)
|
||||
return <div className="flex h-screen w-screen items-center justify-center"><Loading /></div>
|
||||
return <>{children}</>
|
||||
}
|
||||
export default GlobalPublicStoreProvider
|
||||
|
||||
@ -106,10 +106,10 @@ export const ModalContextProvider = ({
|
||||
|
||||
const [showAnnotationFullModal, setShowAnnotationFullModal] = useState(false)
|
||||
const handleCancelAccountSettingModal = () => {
|
||||
const educationVerifying = localStorage.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
const educationVerifying = globalThis.localStorage?.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
|
||||
if (educationVerifying === 'yes')
|
||||
localStorage.removeItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
globalThis.localStorage?.removeItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
|
||||
accountSettingCallbacksRef.current?.onCancelCallback?.()
|
||||
accountSettingCallbacksRef.current = null
|
||||
|
||||
@ -2,15 +2,13 @@
|
||||
|
||||
import type { FC, PropsWithChildren } from 'react'
|
||||
import type { ChatConfig } from '@/app/components/base/chat/types'
|
||||
import type { AccessMode } from '@/models/access-control'
|
||||
import type { AppData, AppMeta } from '@/models/share'
|
||||
import { useEffect } from 'react'
|
||||
import { create } from 'zustand'
|
||||
import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { usePathname, useSearchParams } from '@/next/navigation'
|
||||
import { useGetWebAppAccessModeByCode } from '@/service/use-share'
|
||||
import { useIsSystemFeaturesPending } from './global-public-context'
|
||||
|
||||
type WebAppStore = {
|
||||
shareCode: string | null
|
||||
@ -19,8 +17,8 @@ type WebAppStore = {
|
||||
updateAppInfo: (appInfo: AppData | null) => void
|
||||
appParams: ChatConfig | null
|
||||
updateAppParams: (appParams: ChatConfig | null) => void
|
||||
webAppAccessMode: AccessMode
|
||||
updateWebAppAccessMode: (accessMode: AccessMode) => void
|
||||
webAppAccessMode: AccessMode | null
|
||||
updateWebAppAccessMode: (accessMode: AccessMode | null) => void
|
||||
appMeta: AppMeta | null
|
||||
updateWebAppMeta: (appMeta: AppMeta | null) => void
|
||||
userCanAccessApp: boolean
|
||||
@ -38,8 +36,8 @@ export const useWebAppStore = create<WebAppStore>(set => ({
|
||||
updateAppInfo: (appInfo: AppData | null) => set(() => ({ appInfo })),
|
||||
appParams: null,
|
||||
updateAppParams: (appParams: ChatConfig | null) => set(() => ({ appParams })),
|
||||
webAppAccessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS,
|
||||
updateWebAppAccessMode: (accessMode: AccessMode) => set(() => ({ webAppAccessMode: accessMode })),
|
||||
webAppAccessMode: null,
|
||||
updateWebAppAccessMode: (accessMode: AccessMode | null) => set(() => ({ webAppAccessMode: accessMode })),
|
||||
appMeta: null,
|
||||
updateWebAppMeta: (appMeta: AppMeta | null) => set(() => ({ appMeta })),
|
||||
userCanAccessApp: false,
|
||||
@ -65,7 +63,6 @@ const getShareCodeFromPathname = (pathname: string): string | null => {
|
||||
}
|
||||
|
||||
const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
||||
const isGlobalPending = useIsSystemFeaturesPending()
|
||||
const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode)
|
||||
const updateShareCode = useWebAppStore(state => state.updateShareCode)
|
||||
const updateEmbeddedUserId = useWebAppStore(state => state.updateEmbeddedUserId)
|
||||
@ -104,24 +101,13 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
||||
}
|
||||
}, [searchParamsString, updateEmbeddedUserId, updateEmbeddedConversationId])
|
||||
|
||||
const { isLoading, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode)
|
||||
const { data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode)
|
||||
|
||||
useEffect(() => {
|
||||
if (accessModeResult?.accessMode)
|
||||
updateWebAppAccessMode(accessModeResult.accessMode)
|
||||
}, [accessModeResult, updateWebAppAccessMode, shareCode])
|
||||
|
||||
if (isGlobalPending || isLoading) {
|
||||
return (
|
||||
<div className="flex h-full w-full items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<>
|
||||
{children}
|
||||
</>
|
||||
)
|
||||
return <>{children}</>
|
||||
}
|
||||
export default WebAppStoreProvider
|
||||
|
||||
@ -169,19 +169,6 @@
|
||||
"count": 9
|
||||
}
|
||||
},
|
||||
"app/(shareLayout)/components/authenticated-layout.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/(shareLayout)/components/splash.tsx": {
|
||||
"react/set-state-in-effect": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/(shareLayout)/webapp-reset-password/check-code/page.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 4
|
||||
@ -233,11 +220,6 @@
|
||||
"count": 13
|
||||
}
|
||||
},
|
||||
"app/(shareLayout)/webapp-signin/page.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/account/(commonLayout)/account-page/AvatarWithEdit.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
@ -307,9 +289,6 @@
|
||||
}
|
||||
},
|
||||
"app/account/oauth/authorize/page.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 5
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
@ -7912,9 +7891,6 @@
|
||||
},
|
||||
"react/set-state-in-effect": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 4
|
||||
}
|
||||
},
|
||||
"app/components/share/text-generation/no-data/index.tsx": {
|
||||
@ -7993,11 +7969,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/splash.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/tools/edit-custom-collection-modal/config-credentials.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
|
||||
@ -21,15 +21,6 @@ const nextConfig: NextConfig = {
|
||||
// https://nextjs.org/docs/api-reference/next.config.js/ignoring-typescript-errors
|
||||
ignoreBuildErrors: true,
|
||||
},
|
||||
async redirects() {
|
||||
return [
|
||||
{
|
||||
source: '/',
|
||||
destination: '/apps',
|
||||
permanent: false,
|
||||
},
|
||||
]
|
||||
},
|
||||
output: 'standalone',
|
||||
compiler: {
|
||||
removeConsole: isDev ? false : { exclude: ['warn', 'error'] },
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
export {
|
||||
redirect,
|
||||
useParams,
|
||||
usePathname,
|
||||
useRouter,
|
||||
|
||||
@ -75,6 +75,28 @@ export default defineConfig(({ mode }) => {
|
||||
// SyntaxError: Named export not found. The requested module is a CommonJS module, which may not support all module.exports as named exports
|
||||
noExternal: ['emoji-mart'],
|
||||
},
|
||||
environments: {
|
||||
rsc: {
|
||||
optimizeDeps: {
|
||||
include: [
|
||||
'lamejs',
|
||||
'lamejs/src/js/BitStream',
|
||||
'lamejs/src/js/Lame',
|
||||
'lamejs/src/js/MPEGMode',
|
||||
],
|
||||
},
|
||||
},
|
||||
ssr: {
|
||||
optimizeDeps: {
|
||||
include: [
|
||||
'lamejs',
|
||||
'lamejs/src/js/BitStream',
|
||||
'lamejs/src/js/Lame',
|
||||
'lamejs/src/js/MPEGMode',
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
|
||||
|
||||
Reference in New Issue
Block a user