Compare commits

..

5 Commits

Author SHA1 Message Date
yyh
9311150bd2 Merge branch 'main' into 4-2-no-global-loading 2026-04-02 19:16:40 +08:00
a3386da5d6 ci: Update pyrefly version to 0.59.1 (#34452)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 09:48:46 +00:00
c49201ee28 [autofix.ci] apply automated fixes 2026-04-02 09:40:01 +00:00
99
318a3d0308 refactor(api): tighten login and wrapper typing (#34447) 2026-04-02 09:36:58 +00:00
d13e6901cf refactor: no global loading 2026-04-02 17:36:06 +08:00
45 changed files with 511 additions and 1405 deletions

View File

@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
)
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: Any, **kwargs: Any):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)
return wrapper

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import overload
from sqlalchemy import select
@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
def get_app_model(
view: Callable[..., Any] | None = None,
@overload
def get_app_model[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@ -68,14 +84,30 @@ def get_app_model(
return decorator(view)
def get_app_model_with_trial(
view: Callable[..., Any] | None = None,
@overload
def get_app_model_with_trial[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model_with_trial[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model_with_trial[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@ -1,6 +1,4 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@ -11,7 +9,6 @@ from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -87,39 +84,3 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@ -1,4 +1,5 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from flask import Response, request
@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite(f):
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -70,7 +71,7 @@ def _api_prerequisite(f):
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)

View File

@ -1,9 +1,10 @@
import inspect
import logging
import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Any, cast, overload
from typing import cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
return interceptor
def validate_dataset_token(
view: Callable[..., Any] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated(*args: Any, **kwargs: Any) -> Any:
api_token = validate_and_get_api_token("dataset")
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
positional_parameters = [
parameter
for parameter in inspect.signature(view).parameters.values()
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None
@wraps(view)
def decorated(*args: object, **kwargs: object) -> R:
api_token = validate_and_get_api_token("dataset")
# First try to get from kwargs (explicit parameter)
dataset_id = kwargs.get("dataset_id")
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
dataset_id = kwargs.get("dataset_id")
# If not in kwargs, try to extract from positional args
if not dataset_id and args:
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
# Check if first arg is likely a class instance (has __dict__ or __class__)
if len(args) > 1 and hasattr(args[0], "__dict__"):
# This is a class method, dataset_id should be in args[1]
potential_id = args[1]
# Validate it's a string-like UUID, not another object
try:
# Try to convert to string and check if it's a valid UUID format
str_id = str(potential_id)
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from class method args")
elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
if not dataset_id and args:
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.limit(1)
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
.limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant does not exist.")
if args and isinstance(args[0], Resource):
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs)
if expects_bound_instance:
if not args:
raise TypeError("validate_dataset_token expected a bound resource instance.")
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
return decorated
return view(api_token.tenant_id, *args, **kwargs)
if view:
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
return decorated
def validate_and_get_api_token(scope: str | None = None):

View File

@ -1,5 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from flask import Flask
if TYPE_CHECKING:
from extensions.ext_login import DifyLoginManager
class DifyApp(Flask):
pass
"""Flask application type with Dify-specific extension attributes."""
login_manager: DifyLoginManager

View File

@ -1,17 +1,56 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -19,3 +58,152 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@ -1,7 +1,8 @@
import json
from typing import cast
import flask_login
from flask import Response, request
from flask import Request, Response, request
from flask_login import user_loaded_from_request, user_logged_in
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
login_manager = flask_login.LoginManager()
type LoginUser = Account | EndUser
class DifyLoginManager(flask_login.LoginManager):
"""Project-specific Flask-Login manager with a stable unauthorized contract.
Dify registers `unauthorized_handler` below to always return a JSON `Response`.
Overriding this method lets callers rely on that narrower return type instead of
Flask-Login's broader callback contract.
"""
def unauthorized(self) -> Response:
"""Return the registered unauthorized handler result as a Flask `Response`."""
return cast(Response, super().unauthorized())
def load_user_from_request_context(self) -> None:
"""Populate Flask-Login's request-local user cache for the current request."""
self._load_user()
login_manager = DifyLoginManager()
# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None:
"""Load user based on the request."""
del request_from_flask_login
# Skip authentication for documentation endpoints
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login):
raise NotFound("End user not found.")
return end_user
return None
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
def on_user_logged_in(_sender: object, user: LoginUser) -> None:
"""Called when a user logged in.
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user):
@login_manager.unauthorized_handler
def unauthorized_handler():
def unauthorized_handler() -> Response:
"""Handle unauthorized requests."""
# Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows
# Flask-Login's callback contract based on this override.
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
@ -123,5 +150,5 @@ def unauthorized_handler():
)
def init_app(app: DifyApp):
def init_app(app: DifyApp) -> None:
login_manager.init_app(app)

View File

@ -2,19 +2,19 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from flask import current_app, g, has_request_context, request
from flask import Response, current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_login import DifyLoginManager
from libs.token import check_csrf_token
from models import Account
if TYPE_CHECKING:
from flask.typing import ResponseReturnValue
from models.model import EndUser
@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None:
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
def current_account_with_tenant():
def _get_login_manager() -> DifyLoginManager:
"""Return the project login manager with Dify's narrowed unauthorized contract."""
app = cast(DifyApp, current_app)
return app.login_manager
def current_account_with_tenant() -> tuple[Account, str]:
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
@ -42,7 +48,7 @@ def current_account_with_tenant():
return user, user.current_tenant_id
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
"""
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# `DifyLoginManager` guarantees that the registered unauthorized handler
# is surfaced here as a concrete Flask `Response`.
unauthorized_response: Response = _get_login_manager().unauthorized()
return unauthorized_response
g._login_user = user
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
def _get_user() -> EndUser | Account | None:
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
_get_login_manager().load_user_from_request_context()
return g._login_user

View File

@ -171,7 +171,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.57.1",
"pyrefly>=0.59.1",
]
############################################################

View File

@ -18,13 +18,12 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -107,7 +106,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@ -117,7 +116,6 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(

View File

@ -22,7 +22,6 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -132,10 +131,9 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
# 7. Check and consume quota
try:
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -155,18 +153,13 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED

View File

@ -2,7 +2,7 @@ import json
import logging
import os
from collections.abc import Sequence
from typing import Literal, NotRequired, TypedDict
from typing import Literal, TypedDict
import httpx
from pydantic import TypeAdapter
@ -32,81 +32,6 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
reservation_id: str
available: int
reserved: int
class QuotaCommitResult(TypedDict):
available: int
reserved: int
refunded: int
class QuotaReleaseResult(TypedDict):
available: int
reserved: int
released: int
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _BillingQuota(TypedDict):
size: int
limit: int
class _VectorSpaceQuota(TypedDict):
size: float
limit: int
class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
# 2. Keep it for compatibility for now, can be deprecated in future versions.
size: NotRequired[int]
# NOTE END
limit: int
class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool
class BillingInfo(TypedDict):
"""Response of /subscription/info.
NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
1. validate_python in non-strict mode will coerce it to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""
enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]
_billing_info_adapter = TypeAdapter(BillingInfo)
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@ -119,69 +44,19 @@ class BillingService:
_PLAN_CACHE_TTL = 600
@classmethod
def get_info(cls, tenant_id: str) -> BillingInfo:
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return _billing_info_adapter.validate_python(billing_info)
return billing_info
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
return cls._send_request("GET", "/quota/info", params=params)
@classmethod
def quota_reserve(
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
"""Reserve quota before task execution."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"request_id": request_id,
"amount": amount,
}
if meta:
payload["meta"] = meta
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
"""Commit a reservation with actual consumption."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
"actual_amount": actual_amount,
}
if meta:
payload["meta"] = meta
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
"""Release a reservation (cancel, return frozen quota)."""
return _quota_release_adapter.validate_python(
cls._send_request(
"POST",
"/quota/release",
json={
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
},
)
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}

View File

@ -281,7 +281,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
@ -312,10 +312,7 @@ class FeatureService:
features.apps.limit = billing_info["apps"]["limit"]
if "vector_space" in billing_info:
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
# but LimitationModel.size is int; truncate here for compatibility
features.vector_space.size = int(billing_info["vector_space"]["size"])
# NOTE END
features.vector_space.size = billing_info["vector_space"]["size"]
features.vector_space.limit = billing_info["vector_space"]["limit"]
if "documents_upload_quota" in billing_info:
@ -336,11 +333,7 @@ class FeatureService:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
# NOTE (hj24):
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
# NOTE END
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]

View File

@ -1,233 +0,0 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota reservation (Reserve phase).
Lifecycle:
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
try:
do_work()
charge.commit() # Confirm consumption
except:
charge.refund() # Release frozen quota
If neither commit() nor refund() is called, the billing system's
cleanup CronJob will auto-release the reservation within ~75 seconds.
"""
success: bool
charge_id: str | None # reservation_id
_quota_type: QuotaType
_tenant_id: str | None = None
_feature_key: str | None = None
_amount: int = 0
_committed: bool = field(default=False, repr=False)
def commit(self, actual_amount: int | None = None) -> None:
"""
Confirm the consumption with actual amount.
Args:
actual_amount: Actual amount consumed. Defaults to the reserved amount.
If less than reserved, the difference is refunded automatically.
"""
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
return
try:
from services.billing_service import BillingService
amount = actual_amount if actual_amount is not None else self._amount
BillingService.quota_commit(
tenant_id=self._tenant_id,
feature_key=self._feature_key,
reservation_id=self.charge_id,
actual_amount=amount,
)
self._committed = True
logger.debug(
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
self._quota_type,
self._tenant_id,
self.charge_id,
amount,
)
except Exception:
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
def refund(self) -> None:
"""
Release the reserved quota (cancel the charge).
Safe to call even if:
- charge failed or was disabled (charge_id is None)
- already committed (Release after Commit is a no-op)
- already refunded (idempotent)
This method guarantees no exceptions will be raised.
"""
if not self.charge_id or not self._tenant_id or not self._feature_key:
return
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
from enums.quota_type import QuotaType
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
class QuotaService:
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
@staticmethod
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve + immediate Commit (one-shot mode).
The returned QuotaCharge supports .refund() which calls Release.
For two-phase usage (e.g. streaming), use reserve() directly.
"""
charge = QuotaService.reserve(quota_type, tenant_id, amount)
if charge.success and charge.charge_id:
charge.commit()
return charge
@staticmethod
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve quota before task execution (Reserve phase only).
The caller MUST call charge.commit() after the task succeeds,
or charge.refund() if the task fails.
Raises:
QuotaExceededError: When quota is insufficient
"""
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to reserve must be greater than 0")
request_id = str(uuid.uuid4())
feature_key = quota_type.billing_key
try:
reserve_resp = BillingService.quota_reserve(
tenant_id=tenant_id,
feature_key=feature_key,
request_id=request_id,
amount=amount,
)
reservation_id = reserve_resp.get("reservation_id")
if not reservation_id:
logger.warning(
"Reserve returned no reservation_id for %s, feature %s, response: %s",
tenant_id,
quota_type.value,
reserve_resp,
)
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
logger.debug(
"Reserved %d %s quota for tenant %s, reservation_id: %s",
amount,
quota_type.value,
tenant_id,
reservation_id,
)
return QuotaCharge(
success=True,
charge_id=reservation_id,
_quota_type=quota_type,
_tenant_id=tenant_id,
_feature_key=feature_key,
_amount=amount,
)
except QuotaExceededError:
raise
except ValueError:
raise
except Exception:
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
return unlimited()
@staticmethod
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = QuotaService.get_remaining(quota_type, tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
return True
@staticmethod
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
"""Release a reservation. Guarantees no exceptions."""
try:
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not reservation_id:
return
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
BillingService.quota_release(
tenant_id=tenant_id,
feature_key=feature_key,
reservation_id=reservation_id,
)
except Exception:
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
@staticmethod
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
from services.billing_service import BillingService
try:
usage_info = BillingService.get_quota_info(tenant_id)
if isinstance(usage_info, dict):
feature_info = usage_info.get(quota_type.billing_key, {})
if isinstance(feature_info, dict):
limit = feature_info.get("limit", 0)
usage = feature_info.get("usage", 0)
if limit == -1:
return -1
return max(0, limit - usage)
return 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
return -1

View File

@ -38,7 +38,6 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@ -783,9 +782,9 @@ class WebhookService:
user_id=None,
)
# reserve quota before triggering workflow execution
# consume quota before triggering workflow execution
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
@ -796,16 +795,11 @@ class WebhookService:
raise
# Trigger workflow execution asynchronously
try:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from models.enums import (
AppTriggerType,
CreatorUserRole,
@ -42,7 +42,6 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@ -299,10 +298,10 @@ def dispatch_triggered_workflow(
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
# reserve quota before invoking trigger
# consume quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
@ -388,7 +387,6 @@ def dispatch_triggered_workflow(
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
quota_charge.commit()
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",

View File

@ -8,11 +8,10 @@ from core.workflow.nodes.trigger_schedule.exc import (
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@ -44,7 +43,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
@ -62,7 +61,6 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()

View File

@ -36,19 +36,12 @@ class TestAppGenerateService:
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.quota_reserve.return_value = {
"reservation_id": "test-reservation-id",
"available": 100,
"reserved": 1,
}
mock_billing_service.quota_commit.return_value = {
"available": 99,
"reserved": 0,
"refunded": 0,
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
"result": "success",
"history_id": "test_history_id",
}
# Setup default mock returns for workflow service
@ -108,8 +101,6 @@ class TestAppGenerateService:
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_quota_dify_config.BILLING_ENABLED = False
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@ -127,7 +118,6 @@ class TestAppGenerateService:
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"quota_dify_config": mock_quota_dify_config,
"global_dify_config": mock_global_dify_config,
}
@ -475,7 +465,6 @@ class TestAppGenerateService:
# Set BILLING_ENABLED to True for this test
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
# Setup test arguments
@ -489,10 +478,8 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
# Verify billing two-phase quota (reserve + commit)
billing = mock_external_service_dependencies["billing_service"]
billing.quota_reserve.assert_called_once()
billing.quota_commit.assert_called_once()
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log(
)
# Mock quota to avoid rate limiting
from services import quota_service
from enums import quota_type
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)

View File

@ -20,7 +20,7 @@ def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["RESTX_MASK_HEADER"] = "X-Fields"
app.login_manager = SimpleNamespace(_load_user=lambda: None)
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return app

View File

@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return flask_app

View File

@ -1,349 +0,0 @@
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
from unittest.mock import patch
import pytest
from enums.quota_type import QuotaType
from services.quota_service import QuotaCharge, QuotaService, unlimited
class TestQuotaType:
def test_billing_key_trigger(self):
assert QuotaType.TRIGGER.billing_key == "trigger_event"
def test_billing_key_workflow(self):
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
def test_billing_key_unlimited_raises(self):
with pytest.raises(ValueError, match="Invalid quota type"):
_ = QuotaType.UNLIMITED.billing_key
class TestQuotaService:
def test_reserve_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService"),
):
mock_cfg.BILLING_ENABLED = False
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_reserve_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
def test_reserve_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
assert charge.success is True
assert charge.charge_id == "rid-1"
assert charge._tenant_id == "t1"
assert charge._feature_key == "trigger_event"
assert charge._amount == 1
mock_bs.quota_reserve.assert_called_once()
def test_reserve_no_reservation_id_raises(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {}
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_quota_exceeded_propagates(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_api_exception_returns_unlimited(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = RuntimeError("network")
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_consume_calls_reserve_and_commit(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
mock_bs.quota_commit.return_value = {}
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
assert charge.success is True
mock_bs.quota_commit.assert_called_once()
def test_check_billing_disabled(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = False
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_check_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
def test_check_sufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=100),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
def test_check_insufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=5),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
def test_check_unlimited_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=-1),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
def test_check_exception_returns_true(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_release_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = False
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_empty_reservation(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.return_value = {}
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_called_once_with(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
)
def test_release_exception_swallowed(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.side_effect = RuntimeError("fail")
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_get_remaining_normal(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
def test_get_remaining_unlimited(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_over_limit_returns_zero(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_exception_returns_neg1(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.side_effect = RuntimeError
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_empty_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_non_dict_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = "invalid"
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_feature_not_in_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
assert remaining == 0
def test_get_remaining_non_dict_feature_info(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
class TestQuotaCharge:
def test_commit_success(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
mock_bs.quota_commit.assert_called_once_with(
tenant_id="t1",
feature_key="trigger_event",
reservation_id="rid-1",
actual_amount=1,
)
assert charge._committed is True
def test_commit_with_actual_amount(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=10,
)
charge.commit(actual_amount=5)
call_kwargs = mock_bs.quota_commit.call_args[1]
assert call_kwargs["actual_amount"] == 5
def test_commit_idempotent(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
charge.commit()
assert mock_bs.quota_commit.call_count == 1
def test_commit_no_charge_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_no_tenant_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
_feature_key="trigger_event",
)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_exception_swallowed(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.side_effect = RuntimeError("fail")
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
def test_refund_success(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
)
charge.refund()
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_refund_no_charge_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.refund()
mock_rel.assert_not_called()
def test_refund_no_tenant_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
)
charge.refund()
mock_rel.assert_not_called()
class TestUnlimited:
def test_unlimited_returns_success_with_no_charge_id(self):
charge = unlimited()
assert charge.success is True
assert charge.charge_id is None
assert charge._quota_type == QuotaType.UNLIMITED

View File

@ -0,0 +1,17 @@
import json
from flask import Response
from extensions.ext_login import unauthorized_handler
def test_unauthorized_handler_returns_json_response() -> None:
response = unauthorized_handler()
assert isinstance(response, Response)
assert response.status_code == 401
assert response.content_type == "application/json"
assert json.loads(response.get_data(as_text=True)) == {
"code": "unauthorized",
"message": "Unauthorized.",
}

View File

@ -2,11 +2,12 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask, g
from flask_login import LoginManager, UserMixin
from flask import Flask, Response, g
from flask_login import UserMixin
from pytest_mock import MockerFixture
import libs.login as login_module
from extensions.ext_login import DifyLoginManager
from libs.login import current_user
from models.account import Account
@ -39,9 +40,12 @@ def login_app(mocker: MockerFixture) -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
login_manager = LoginManager()
login_manager = DifyLoginManager()
login_manager.init_app(app)
login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized")
login_manager.unauthorized = mocker.Mock(
name="unauthorized",
return_value=Response("Unauthorized", status=401, content_type="application/json"),
)
@login_manager.user_loader
def load_user(_user_id: str):
@ -109,18 +113,43 @@ class TestLoginRequired:
resolved_user: MockUser | None,
description: str,
):
"""Test that missing or unauthenticated users are redirected."""
"""Test that missing or unauthenticated users return the manager response."""
resolve_user = resolve_current_user(resolved_user)
with login_app.test_request_context():
result = protected_view()
assert result == "Unauthorized", description
assert result is login_app.login_manager.unauthorized.return_value, description
assert isinstance(result, Response)
assert result.status_code == 401
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
csrf_check.assert_not_called()
def test_unauthorized_access_propagates_response_object(
self,
login_app: Flask,
protected_view,
csrf_check: MagicMock,
resolve_current_user,
mocker: MockerFixture,
) -> None:
"""Test that unauthorized responses are propagated as Flask Response objects."""
resolve_user = resolve_current_user(None)
response = Response("Unauthorized", status=401, content_type="application/json")
mocker.patch.object(
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
)
with login_app.test_request_context():
result = protected_view()
assert result is response
assert isinstance(result, Response)
resolve_user.assert_called_once_with()
csrf_check.assert_not_called()
@pytest.mark.parametrize(
("method", "login_disabled"),
[
@ -168,10 +197,14 @@ class TestGetUser:
"""Test that _get_user loads user if not already in g."""
mock_user = MockUser("test_user")
def _load_user() -> None:
def load_user_from_request_context() -> None:
g._login_user = mock_user
load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user)
load_user = mocker.patch.object(
login_app.login_manager,
"load_user_from_request_context",
side_effect=load_user_from_request_context,
)
with login_app.test_request_context():
user = login_module._get_user()

View File

@ -23,7 +23,6 @@ import pytest
import services.app_generate_service as ags_module
from core.app.entities.app_invoke_entities import InvokeFrom
from enums.quota_type import QuotaType
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
@ -448,8 +447,8 @@ class TestGenerateBilling:
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
reserve_mock = mocker.patch(
"services.app_generate_service.QuotaService.reserve",
consume_mock = mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
return_value=quota_charge,
)
mocker.patch(
@ -468,8 +467,7 @@ class TestGenerateBilling:
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
quota_charge.commit.assert_called_once()
consume_mock.assert_called_once_with("tenant-id")
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
from services.errors.app import QuotaExceededError
@ -477,7 +475,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
mocker.patch(
"services.app_generate_service.QuotaService.reserve",
"services.app_generate_service.QuotaType.WORKFLOW.consume",
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
)
@ -494,7 +492,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
mocker.patch(
"services.app_generate_service.QuotaService.reserve",
"services.app_generate_service.QuotaType.WORKFLOW.consume",
return_value=quota_charge,
)
mocker.patch(

View File

@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
- repo: SQLAlchemyWorkflowTriggerLogRepository
- dispatcher_manager_class: QueueDispatcherManager class
- dispatcher: dispatcher instance
- quota_service: QuotaService mock
- quota_workflow: QuotaType.WORKFLOW
- get_workflow: AsyncWorkflowService._get_workflow method
- professional_task: execute_workflow_professional
- team_task: execute_workflow_team
@ -72,7 +72,7 @@ class TestAsyncWorkflowService:
mock_repo.create.side_effect = _create_side_effect
mock_dispatcher = MagicMock()
mock_quota_service = MagicMock()
quota_workflow = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
@ -93,8 +93,8 @@ class TestAsyncWorkflowService:
) as mock_get_workflow,
patch.object(
async_workflow_service_module,
"QuotaService",
new=mock_quota_service,
"QuotaType",
new=SimpleNamespace(WORKFLOW=quota_workflow),
),
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
@ -107,7 +107,7 @@ class TestAsyncWorkflowService:
"repo": mock_repo,
"dispatcher_manager_class": mock_dispatcher_manager_class,
"dispatcher": mock_dispatcher,
"quota_service": mock_quota_service,
"quota_workflow": quota_workflow,
"get_workflow": mock_get_workflow,
"professional_task": mock_professional_task,
"team_task": mock_team_task,
@ -146,9 +146,6 @@ class TestAsyncWorkflowService:
mocks["team_task"].delay.return_value = task_result
mocks["sandbox_task"].delay.return_value = task_result
quota_charge_mock = MagicMock()
mocks["quota_service"].reserve.return_value = quota_charge_mock
class DummyAccount:
def __init__(self, user_id: str):
self.id = user_id
@ -166,8 +163,7 @@ class TestAsyncWorkflowService:
assert result.status == "queued"
assert result.queue == queue_name
mocks["quota_service"].reserve.assert_called_once()
quota_charge_mock.commit.assert_called_once()
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
assert session.commit.call_count == 2
created_log = mocks["repo"].create.call_args[0][0]
@ -254,7 +250,7 @@ class TestAsyncWorkflowService:
mocks = async_workflow_trigger_mocks
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
mocks["get_workflow"].return_value = workflow
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
feature="workflow",
tenant_id="tenant-123",
required=1,

View File

@ -290,19 +290,9 @@ class TestBillingServiceSubscriptionInfo:
# Arrange
tenant_id = "tenant-123"
expected_response = {
"enabled": True,
"subscription": {"plan": "professional", "interval": "month", "education": False},
"members": {"size": 1, "limit": 50},
"apps": {"size": 1, "limit": 200},
"vector_space": {"size": 0.0, "limit": 20480},
"knowledge_rate_limit": {"limit": 1000},
"documents_upload_quota": {"size": 0, "limit": 1000},
"annotation_quota_limit": {"size": 0, "limit": 5000},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": 1775952000,
"subscription_plan": "professional",
"billing_cycle": "monthly",
"status": "active",
}
mock_send_request.return_value = expected_response
@ -425,7 +415,7 @@ class TestBillingServiceUsageCalculation:
yield mock
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
"""Test retrieval of tenant feature plan usage information."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
@ -438,20 +428,6 @@ class TestBillingServiceUsageCalculation:
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
def test_get_quota_info(self, mock_send_request):
"""Test retrieval of quota info from new endpoint."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
mock_send_request.return_value = expected_response
# Act
result = BillingService.get_quota_info(tenant_id)
# Assert
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
"""Test updating tenant feature usage with positive delta (adding credits)."""
# Arrange
@ -529,118 +505,6 @@ class TestBillingServiceUsageCalculation:
)
class TestBillingServiceQuotaOperations:
"""Unit tests for quota reserve/commit/release operations."""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
def test_quota_reserve_success(self, mock_send_request):
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
mock_send_request.return_value = expected
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/reserve",
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
)
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
assert result["available"] == 99
assert isinstance(result["available"], int)
assert result["reserved"] == 1
assert isinstance(result["reserved"], int)
def test_quota_reserve_with_meta(self, mock_send_request):
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
meta = {"source": "webhook"}
BillingService.quota_reserve(
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"source": "webhook"}
def test_quota_commit_success(self, mock_send_request):
expected = {"available": 98, "reserved": 0, "refunded": 0}
mock_send_request.return_value = expected
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/commit",
json={
"tenant_id": "t1",
"feature_key": "trigger_event",
"reservation_id": "rid-1",
"actual_amount": 1,
},
)
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
)
assert result["available"] == 97
assert isinstance(result["available"], int)
assert result["refunded"] == 1
assert isinstance(result["refunded"], int)
def test_quota_commit_with_meta(self, mock_send_request):
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
meta = {"reason": "partial"}
BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"reason": "partial"}
def test_quota_release_success(self, mock_send_request):
expected = {"available": 100, "reserved": 0, "released": 1}
mock_send_request.return_value = expected
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/release",
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
)
def test_quota_release_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
assert result["available"] == 100
assert isinstance(result["available"], int)
assert result["released"] == 1
assert isinstance(result["released"], int)
class TestBillingServiceRateLimitEnforcement:
"""Unit tests for rate limit enforcement mechanisms.
@ -1131,14 +995,17 @@ class TestBillingServiceEdgeCases:
yield mock
def test_get_info_empty_response(self, mock_send_request):
"""Empty response from billing API should raise ValidationError due to missing required fields."""
from pydantic import ValidationError
"""Test handling of empty billing info response."""
# Arrange
tenant_id = "tenant-empty"
mock_send_request.return_value = {}
with pytest.raises(ValidationError):
BillingService.get_info(tenant_id)
# Act
result = BillingService.get_info(tenant_id)
# Assert
assert result == {}
mock_send_request.assert_called_once()
def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request):
"""Test updating tenant feature usage with zero delta (no change)."""
@ -1549,21 +1416,12 @@ class TestBillingServiceIntegrationScenarios:
# Step 1: Get current billing info
mock_send_request.return_value = {
"enabled": True,
"subscription": {"plan": "sandbox", "interval": "", "education": False},
"members": {"size": 0, "limit": 1},
"apps": {"size": 0, "limit": 5},
"vector_space": {"size": 0.0, "limit": 50},
"knowledge_rate_limit": {"limit": 10},
"documents_upload_quota": {"size": 0, "limit": 50},
"annotation_quota_limit": {"size": 0, "limit": 10},
"docs_processing": "standard",
"can_replace_logo": False,
"model_load_balancing_enabled": False,
"knowledge_pipeline_publish_enabled": False,
"subscription_plan": "sandbox",
"billing_cycle": "monthly",
"status": "active",
}
current_info = BillingService.get_info(tenant_id)
assert current_info["subscription"]["plan"] == "sandbox"
assert current_info["subscription_plan"] == "sandbox"
# Step 2: Get payment link for upgrade
mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"}
@ -1677,140 +1535,3 @@ class TestBillingServiceIntegrationScenarios:
mock_send_request.return_value = {"result": "success", "activated": True}
activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student")
assert activate_result["activated"] is True
class TestBillingServiceSubscriptionInfoDataType:
"""Unit tests for data type coercion in BillingService.get_info
1. Verifies the get_info returns correct Python types for numeric fields
2. Ensure the compatibility regardless of what results the upstream billing API returns
"""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
@pytest.fixture
def normal_billing_response(self) -> dict:
return {
"enabled": True,
"subscription": {
"plan": "team",
"interval": "year",
"education": False,
},
"members": {"size": 10, "limit": 50},
"apps": {"size": 80, "limit": 200},
"vector_space": {"size": 5120.75, "limit": 20480},
"knowledge_rate_limit": {"limit": 1000},
"documents_upload_quota": {"size": 450, "limit": 1000},
"annotation_quota_limit": {"size": 1200, "limit": 5000},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": 1745971200,
}
@pytest.fixture
def string_billing_response(self) -> dict:
return {
"enabled": True,
"subscription": {
"plan": "team",
"interval": "year",
"education": False,
},
"members": {"size": "10", "limit": "50"},
"apps": {"size": "80", "limit": "200"},
"vector_space": {"size": "5120.75", "limit": "20480"},
"knowledge_rate_limit": {"limit": "1000"},
"documents_upload_quota": {"size": "450", "limit": "1000"},
"annotation_quota_limit": {"size": "1200", "limit": "5000"},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": "1745971200",
}
@staticmethod
def _assert_billing_info_types(result: dict):
assert isinstance(result["enabled"], bool)
assert isinstance(result["subscription"]["plan"], str)
assert isinstance(result["subscription"]["interval"], str)
assert isinstance(result["subscription"]["education"], bool)
assert isinstance(result["members"]["size"], int)
assert isinstance(result["members"]["limit"], int)
assert isinstance(result["apps"]["size"], int)
assert isinstance(result["apps"]["limit"], int)
assert isinstance(result["vector_space"]["size"], float)
assert isinstance(result["vector_space"]["limit"], int)
assert isinstance(result["knowledge_rate_limit"]["limit"], int)
assert isinstance(result["documents_upload_quota"]["size"], int)
assert isinstance(result["documents_upload_quota"]["limit"], int)
assert isinstance(result["annotation_quota_limit"]["size"], int)
assert isinstance(result["annotation_quota_limit"]["limit"], int)
assert isinstance(result["docs_processing"], str)
assert isinstance(result["can_replace_logo"], bool)
assert isinstance(result["model_load_balancing_enabled"], bool)
assert isinstance(result["knowledge_pipeline_publish_enabled"], bool)
if "next_credit_reset_date" in result:
assert isinstance(result["next_credit_reset_date"], int)
def test_get_info_with_normal_types(self, mock_send_request, normal_billing_response):
"""When the billing API returns native numeric types, get_info should preserve them."""
mock_send_request.return_value = normal_billing_response
result = BillingService.get_info("tenant-type-test")
self._assert_billing_info_types(result)
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
def test_get_info_with_string_types(self, mock_send_request, string_billing_response):
"""When the billing API returns numeric values as strings, get_info should coerce them."""
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
self._assert_billing_info_types(result)
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response):
"""NotRequired fields can be absent without raising."""
del string_billing_response["next_credit_reset_date"]
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
assert "next_credit_reset_date" not in result
self._assert_billing_info_types(result)
def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response):
"""Undefined fields are silently stripped by validate_python."""
string_billing_response["new_feature"] = "something"
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
# extra fields are dropped by TypeAdapter on TypedDict
assert "new_feature" not in result
self._assert_billing_info_types(result)
def test_get_info_missing_required_field_raises(self, mock_send_request, string_billing_response):
"""Missing a required field should raise ValidationError."""
from pydantic import ValidationError
del string_billing_response["members"]
mock_send_request.return_value = string_billing_response
with pytest.raises(ValidationError):
BillingService.get_info("tenant-type-test")

40
api/uv.lock generated
View File

@ -53,23 +53,6 @@ dependencies = [
]
sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748, upload-time = "2026-03-28T17:19:40.6Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/7e/cb94129302d78c46662b47f9897d642fd0b33bdfef4b73b20c6ced35aa4c/aiohttp-3.13.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8ea0c64d1bcbf201b285c2246c51a0c035ba3bbd306640007bc5844a3b4658c1", size = 760027, upload-time = "2026-03-28T17:15:33.022Z" },
{ url = "https://files.pythonhosted.org/packages/5e/cd/2db3c9397c3bd24216b203dd739945b04f8b87bb036c640da7ddb63c75ef/aiohttp-3.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6f742e1fa45c0ed522b00ede565e18f97e4cf8d1883a712ac42d0339dfb0cce7", size = 508325, upload-time = "2026-03-28T17:15:34.714Z" },
{ url = "https://files.pythonhosted.org/packages/36/a3/d28b2722ec13107f2e37a86b8a169897308bab6a3b9e071ecead9d67bd9b/aiohttp-3.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dcfb50ee25b3b7a1222a9123be1f9f89e56e67636b561441f0b304e25aaef8f", size = 502402, upload-time = "2026-03-28T17:15:36.409Z" },
{ url = "https://files.pythonhosted.org/packages/fa/d6/acd47b5f17c4430e555590990a4746efbcb2079909bb865516892bf85f37/aiohttp-3.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3262386c4ff370849863ea93b9ea60fd59c6cf56bf8f93beac625cf4d677c04d", size = 1771224, upload-time = "2026-03-28T17:15:38.223Z" },
{ url = "https://files.pythonhosted.org/packages/98/af/af6e20113ba6a48fd1cd9e5832c4851e7613ef50c7619acdaee6ec5f1aff/aiohttp-3.13.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:473bb5aa4218dd254e9ae4834f20e31f5a0083064ac0136a01a62ddbae2eaa42", size = 1731530, upload-time = "2026-03-28T17:15:39.988Z" },
{ url = "https://files.pythonhosted.org/packages/81/16/78a2f5d9c124ad05d5ce59a9af94214b6466c3491a25fb70760e98e9f762/aiohttp-3.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e56423766399b4c77b965f6aaab6c9546617b8994a956821cc507d00b91d978c", size = 1827925, upload-time = "2026-03-28T17:15:41.944Z" },
{ url = "https://files.pythonhosted.org/packages/2a/1f/79acf0974ced805e0e70027389fccbb7d728e6f30fcac725fb1071e63075/aiohttp-3.13.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8af249343fafd5ad90366a16d230fc265cf1149f26075dc9fe93cfd7c7173942", size = 1923579, upload-time = "2026-03-28T17:15:44.071Z" },
{ url = "https://files.pythonhosted.org/packages/af/53/29f9e2054ea6900413f3b4c3eb9d8331f60678ec855f13ba8714c47fd48d/aiohttp-3.13.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bc0a5cf4f10ef5a2c94fdde488734b582a3a7a000b131263e27c9295bd682d9", size = 1767655, upload-time = "2026-03-28T17:15:45.911Z" },
{ url = "https://files.pythonhosted.org/packages/f3/57/462fe1d3da08109ba4aa8590e7aed57c059af2a7e80ec21f4bac5cfe1094/aiohttp-3.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5c7ff1028e3c9fc5123a865ce17df1cb6424d180c503b8517afbe89aa566e6be", size = 1630439, upload-time = "2026-03-28T17:15:48.11Z" },
{ url = "https://files.pythonhosted.org/packages/d7/4b/4813344aacdb8127263e3eec343d24e973421143826364fa9fc847f6283f/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ba5cf98b5dcb9bddd857da6713a503fa6d341043258ca823f0f5ab7ab4a94ee8", size = 1745557, upload-time = "2026-03-28T17:15:50.13Z" },
{ url = "https://files.pythonhosted.org/packages/d4/01/1ef1adae1454341ec50a789f03cfafe4c4ac9c003f6a64515ecd32fe4210/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d85965d3ba21ee4999e83e992fecb86c4614d6920e40705501c0a1f80a583c12", size = 1741796, upload-time = "2026-03-28T17:15:52.351Z" },
{ url = "https://files.pythonhosted.org/packages/22/04/8cdd99af988d2aa6922714d957d21383c559835cbd43fbf5a47ddf2e0f05/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:49f0b18a9b05d79f6f37ddd567695943fcefb834ef480f17a4211987302b2dc7", size = 1805312, upload-time = "2026-03-28T17:15:54.407Z" },
{ url = "https://files.pythonhosted.org/packages/fb/7f/b48d5577338d4b25bbdbae35c75dbfd0493cb8886dc586fbfb2e90862239/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7f78cb080c86fbf765920e5f1ef35af3f24ec4314d6675d0a21eaf41f6f2679c", size = 1621751, upload-time = "2026-03-28T17:15:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/bc/89/4eecad8c1858e6d0893c05929e22343e0ebe3aec29a8a399c65c3cc38311/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:67a3ec705534a614b68bbf1c70efa777a21c3da3895d1c44510a41f5a7ae0453", size = 1826073, upload-time = "2026-03-28T17:15:58.489Z" },
{ url = "https://files.pythonhosted.org/packages/f5/5c/9dc8293ed31b46c39c9c513ac7ca152b3c3d38e0ea111a530ad12001b827/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6630ec917e85c5356b2295744c8a97d40f007f96a1c76bf1928dc2e27465393", size = 1760083, upload-time = "2026-03-28T17:16:00.677Z" },
{ url = "https://files.pythonhosted.org/packages/1e/19/8bbf6a4994205d96831f97b7d21a0feed120136e6267b5b22d229c6dc4dc/aiohttp-3.13.4-cp311-cp311-win32.whl", hash = "sha256:54049021bc626f53a5394c29e8c444f726ee5a14b6e89e0ad118315b1f90f5e3", size = 439690, upload-time = "2026-03-28T17:16:02.902Z" },
{ url = "https://files.pythonhosted.org/packages/0c/f5/ac409ecd1007528d15c3e8c3a57d34f334c70d76cfb7128a28cffdebd4c1/aiohttp-3.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:c033f2bc964156030772d31cbf7e5defea181238ce1f87b9455b786de7d30145", size = 463824, upload-time = "2026-03-28T17:16:05.058Z" },
{ url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158, upload-time = "2026-03-28T17:16:06.901Z" },
{ url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037, upload-time = "2026-03-28T17:16:08.82Z" },
{ url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556, upload-time = "2026-03-28T17:16:10.63Z" },
@ -1586,7 +1569,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.19.1" },
{ name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.57.1" },
{ name = "pyrefly", specifier = ">=0.59.1" },
{ name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" },
@ -4839,18 +4822,19 @@ wheels = [
[[package]]
name = "pyrefly"
version = "0.57.1"
version = "0.59.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" },
{ url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" },
{ url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" },
{ url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" },
{ url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" },
{ url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" },
{ url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" },
{ url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" },
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
]
[[package]]

View File

@ -1,5 +1,4 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import InSiteMessageNotification from '@/app/components/app/in-site-message/notification'
import AmplitudeProvider from '@/app/components/base/amplitude'
@ -14,7 +13,6 @@ import { EventEmitterContextProvider } from '@/context/event-emitter-provider'
import { ModalContextProvider } from '@/context/modal-context-provider'
import { ProviderContextProvider } from '@/context/provider-context-provider'
import PartnerStack from '../components/billing/partner-stack'
import Splash from '../components/splash'
import RoleRouteGuard from './role-route-guard'
const Layout = ({ children }: { children: ReactNode }) => {
@ -37,7 +35,6 @@ const Layout = ({ children }: { children: ReactNode }) => {
<PartnerStack />
<ReadmePanel />
<GotoAnything />
<Splash />
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>

View File

@ -1,10 +1,8 @@
'use client'
import type { ReactNode } from 'react'
import { useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import { useAppContext } from '@/context/app-context'
import { usePathname, useRouter } from '@/next/navigation'
import { redirect, usePathname } from '@/next/navigation'
const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const
@ -13,21 +11,11 @@ const isPathUnderRoute = (pathname: string, route: string) => pathname === route
export default function RoleRouteGuard({ children }: { children: ReactNode }) {
const { isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
const pathname = usePathname()
const router = useRouter()
const shouldGuardRoute = datasetOperatorRedirectRoutes.some(route => isPathUnderRoute(pathname, route))
const shouldRedirect = shouldGuardRoute && !isLoadingCurrentWorkspace && isCurrentWorkspaceDatasetOperator
useEffect(() => {
if (shouldRedirect)
router.replace('/datasets')
}, [shouldRedirect, router])
// Block rendering only for guarded routes to avoid permission flicker.
if (shouldGuardRoute && isLoadingCurrentWorkspace)
return <Loading type="app" />
if (shouldRedirect)
return null
return redirect('/datasets')
return <>{children}</>
}

View File

@ -3,7 +3,7 @@
import type { ReactNode } from 'react'
import Cookies from 'js-cookie'
import { parseAsBoolean, useQueryState } from 'nuqs'
import { useCallback, useEffect, useState } from 'react'
import { useCallback, useEffect } from 'react'
import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
@ -25,7 +25,6 @@ export const AppInitializer = ({
const searchParams = useSearchParams()
// Tokens are now stored in cookies, no need to check localStorage
const pathname = usePathname()
const [init, setInit] = useState(false)
const [oauthNewUser] = useQueryState(
'oauth_new_user',
parseAsBoolean.withOptions({ history: 'replace' }),
@ -87,10 +86,7 @@ export const AppInitializer = ({
const redirectUrl = resolvePostLoginRedirect()
if (redirectUrl) {
location.replace(redirectUrl)
return
}
setInit(true)
}
catch {
router.replace('/signin')
@ -98,5 +94,5 @@ export const AppInitializer = ({
})()
}, [isSetupFinished, router, pathname, searchParams, oauthNewUser])
return init ? children : null
return children
}

View File

@ -2,8 +2,6 @@ import { render } from '@testing-library/react'
import PartnerStackCookieRecorder from '../cookie-recorder'
let isCloudEdition = true
let psPartnerKey: string | undefined
let psClickId: string | undefined
const saveOrUpdate = vi.fn()
@ -15,8 +13,6 @@ vi.mock('@/config', () => ({
vi.mock('../use-ps-info', () => ({
default: () => ({
psPartnerKey,
psClickId,
saveOrUpdate,
}),
}))
@ -25,8 +21,6 @@ describe('PartnerStackCookieRecorder', () => {
beforeEach(() => {
vi.clearAllMocks()
isCloudEdition = true
psPartnerKey = undefined
psClickId = undefined
})
it('should call saveOrUpdate once on mount when running in cloud edition', () => {
@ -48,16 +42,4 @@ describe('PartnerStackCookieRecorder', () => {
expect(container.innerHTML).toBe('')
})
it('should call saveOrUpdate again when partner stack query changes', () => {
const { rerender } = render(<PartnerStackCookieRecorder />)
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
psPartnerKey = 'updated-partner'
psClickId = 'updated-click'
rerender(<PartnerStackCookieRecorder />)
expect(saveOrUpdate).toHaveBeenCalledTimes(2)
})
})

View File

@ -5,13 +5,13 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'
const PartnerStackCookieRecorder = () => {
const { psPartnerKey, psClickId, saveOrUpdate } = usePSInfo()
const { saveOrUpdate } = usePSInfo()
useEffect(() => {
if (!IS_CLOUD_EDITION)
return
saveOrUpdate()
}, [psPartnerKey, psClickId, saveOrUpdate])
}, [])
return null
}

View File

@ -6,7 +6,7 @@ import { IS_CLOUD_EDITION } from '@/config'
import usePSInfo from './use-ps-info'
const PartnerStack: FC = () => {
const { psPartnerKey, psClickId, saveOrUpdate, bind } = usePSInfo()
const { saveOrUpdate, bind } = usePSInfo()
useEffect(() => {
if (!IS_CLOUD_EDITION)
return
@ -14,7 +14,7 @@ const PartnerStack: FC = () => {
saveOrUpdate()
// bind PartnerStack info after user logged in
bind()
}, [psPartnerKey, psClickId, saveOrUpdate, bind])
}, [])
return null
}

View File

@ -27,8 +27,6 @@ const usePSInfo = () => {
const domain = globalThis.location?.hostname.replace('cloud', '')
const saveOrUpdate = useCallback(() => {
if (hasBind)
return
if (!psPartnerKey || !psClickId)
return
if (!isPSChanged)
@ -41,21 +39,9 @@ const usePSInfo = () => {
path: '/',
domain,
})
}, [psPartnerKey, psClickId, isPSChanged, domain, hasBind])
}, [psPartnerKey, psClickId, isPSChanged, domain])
const bind = useCallback(async () => {
// for debug
if (!hasBind)
fetch("https://cloud.dify.dev/console/api/billing/debug/data", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
type: "bind",
data: psPartnerKey ? JSON.stringify({ psPartnerKey, psClickId }) : "",
}),
})
if (psPartnerKey && psClickId && !hasBind) {
let shouldRemoveCookie = false
try {

View File

@ -18,7 +18,7 @@ const HeaderWrapper = ({
// Check if the current path is a workflow canvas & fullscreen
const inWorkflowCanvas = pathname.endsWith('/workflow')
const isPipelineCanvas = pathname.endsWith('/pipeline')
const workflowCanvasMaximize = localStorage.getItem('workflow-canvas-maximize') === 'true'
const workflowCanvasMaximize = typeof localStorage !== 'undefined' && localStorage.getItem('workflow-canvas-maximize') === 'true'
const [hideHeader, setHideHeader] = useState(workflowCanvasMaximize)
const { eventEmitter } = useEventEmitterContextContext()
@ -28,7 +28,7 @@ const HeaderWrapper = ({
})
return (
<div className={cn('sticky left-0 right-0 top-0 z-30 flex min-h-[56px] shrink-0 grow-0 basis-auto flex-col', s.header, isBordered ? 'border-b border-divider-regular' : '', hideHeader && (inWorkflowCanvas || isPipelineCanvas) && 'hidden')}>
<div className={cn('sticky top-0 right-0 left-0 z-30 flex min-h-[56px] shrink-0 grow-0 basis-auto flex-col', s.header, isBordered ? 'border-b border-divider-regular' : '', hideHeader && (inWorkflowCanvas || isPipelineCanvas) && 'hidden')}>
{children}
</div>
)

View File

@ -1,21 +0,0 @@
'use client'
import type { FC, PropsWithChildren } from 'react'
import * as React from 'react'
import { useIsLogin } from '@/service/use-common'
import Loading from './base/loading'
const Splash: FC<PropsWithChildren> = () => {
// would auto redirect to signin page if not logged in
const { isLoading, data: loginData } = useIsLogin()
const isLoggedIn = loginData?.logged_in
if (isLoading || !isLoggedIn) {
return (
<div className="fixed inset-0 z-9999999 flex h-full items-center justify-center bg-background-body">
<Loading />
</div>
)
}
return null
}
export default React.memo(Splash)

View File

@ -133,7 +133,7 @@ const useEducationReverifyNotice = ({
export const useEducationInit = () => {
const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal)
const setShowEducationExpireNoticeModal = useModalContextSelector(s => s.setShowEducationExpireNoticeModal)
const educationVerifying = localStorage.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
const educationVerifying = typeof localStorage !== 'undefined' && localStorage.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
const searchParams = useSearchParams()
const educationVerifyAction = searchParams.get('action')

View File

@ -1,18 +1,5 @@
import Loading from '@/app/components/base/loading'
import Link from '@/next/link'
import { redirect } from '@/next/navigation'
const Home = async () => {
return (
<div className="flex min-h-screen flex-col justify-center py-12 sm:px-6 lg:px-8">
<div className="sm:mx-auto sm:w-full sm:max-w-md">
<Loading type="area" />
<div className="mt-10 text-center">
<Link href="/apps">🚀</Link>
</div>
</div>
</div>
)
export default function Home() {
return redirect('/apps')
}
export default Home

View File

@ -3,7 +3,6 @@ import type { FC, PropsWithChildren } from 'react'
import type { SystemFeatures } from '@/types/feature'
import { useQuery } from '@tanstack/react-query'
import { create } from 'zustand'
import Loading from '@/app/components/base/loading'
import { consoleClient } from '@/service/client'
import { defaultSystemFeatures } from '@/types/feature'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
@ -53,13 +52,11 @@ const GlobalPublicStoreProvider: FC<PropsWithChildren> = ({
}) => {
// Fetch systemFeatures and setupStatus in parallel to reduce waterfall.
// setupStatus is prefetched here and cached in localStorage for AppInitializer.
const { isPending } = useSystemFeaturesQuery()
useSystemFeaturesQuery()
// Prefetch setupStatus for AppInitializer (result not needed here)
useSetupStatusQuery()
if (isPending)
return <div className="flex h-screen w-screen items-center justify-center"><Loading /></div>
return <>{children}</>
}
export default GlobalPublicStoreProvider

View File

@ -6698,9 +6698,6 @@
}
},
"app/components/header/header-wrapper.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}
@ -7993,11 +7990,6 @@
"count": 1
}
},
"app/components/splash.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 1
}
},
"app/components/tools/edit-custom-collection-modal/config-credentials.tsx": {
"no-restricted-imports": {
"count": 1

View File

@ -21,15 +21,6 @@ const nextConfig: NextConfig = {
// https://nextjs.org/docs/api-reference/next.config.js/ignoring-typescript-errors
ignoreBuildErrors: true,
},
async redirects() {
return [
{
source: '/',
destination: '/apps',
permanent: false,
},
]
},
output: 'standalone',
compiler: {
removeConsole: isDev ? false : { exclude: ['warn', 'error'] },

View File

@ -1,4 +1,5 @@
export {
redirect,
useParams,
usePathname,
useRouter,