fix: address PR review feedback on enterprise license enforcement

- Cache invalid license statuses with 30s TTL to prevent DoS amplification
- Return LicenseStatus enum (not raw str) from get_cached_license_status
- Flatten nested try/except into _read_cached_license_status / _fetch_and_cache_license_status helpers
- Escalate log levels from debug to warning with exc_info for cache failures
- Switch before_request license check from fail-open to fail-closed
- Remove dead raise_for_status parameter from BaseRequest.send_request
- Gate license expired_at behind is_authenticated; only expose status to unauthenticated callers (CVE-2025-63387)
- Remove redundant 'not is_console_api' guard in before_request
- Add 8 unit tests for get_cached_license_status
This commit is contained in:
GareArc
2026-03-08 17:00:12 -07:00
parent de72bdef71
commit 41af72449d
5 changed files with 210 additions and 53 deletions

View File

@ -40,7 +40,7 @@ def create_flask_app_with_configs() -> DifyApp:
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/") and not is_console_api
is_webapp_api = request.path.startswith("/api/")
if is_console_api or is_webapp_api:
if is_console_api:
@ -56,7 +56,7 @@ def create_flask_app_with_configs() -> DifyApp:
if not is_exempt:
try:
# Check license status with caching (10 min TTL)
# Check license status (cached — see EnterpriseService for TTL details)
license_status = EnterpriseService.get_cached_license_status()
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
raise UnauthorizedAndForceLogout(
@ -65,7 +65,13 @@ def create_flask_app_with_configs() -> DifyApp:
except UnauthorizedAndForceLogout:
raise
except Exception:
# Fail-closed: if we cannot verify the license (Redis down +
# enterprise API unreachable), block the request. An unreachable
# sidecar is itself an abnormal state that should surface.
logger.exception("Failed to check enterprise license status")
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context

View File

@ -48,7 +48,6 @@ class BaseRequest:
params: Mapping[str, Any] | None = None,
*,
timeout: float | httpx.Timeout | None = None,
raise_for_status: bool = False,
) -> Any:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
@ -72,14 +71,9 @@ class BaseRequest:
response = client.request(method, url, **request_kwargs)
# Always validate HTTP status and raise domain-specific errors
# Validate HTTP status and raise domain-specific errors
if not response.is_success:
cls._handle_error_response(response)
# Legacy support: still respect raise_for_status parameter
if raise_for_status:
response.raise_for_status()
return response.json()
@classmethod

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import logging
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -8,12 +11,16 @@ from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
# License status cache configuration
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
LICENSE_STATUS_CACHE_TTL = 600 # 10 minutes
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
class WebAppSettings(BaseModel):
@ -56,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True")
return self
@ -119,7 +126,6 @@ class EnterpriseService:
"/default-workspace/members",
json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
)
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
@ -229,45 +235,62 @@ class EnterpriseService:
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def get_cached_license_status(cls):
"""
Get enterprise license status with Redis caching to reduce HTTP calls.
def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls.
Only caches valid statuses (active/expiring) since invalid statuses
should be re-checked every request — the admin may update the license
at any time.
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
(inactive/expired/lost) for 1 minute. The shorter TTL for invalid statuses
balances prompt license-fix detection against DoS mitigation — without
caching, every request on an expired license would hit the enterprise API.
Returns license status string or None if unavailable.
Returns:
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
"""
if not dify_config.ENTERPRISE_ENABLED:
return None
# Try cache first — only valid statuses are cached
try:
cached_status = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if cached_status:
if isinstance(cached_status, bytes):
cached_status = cached_status.decode("utf-8")
return cached_status
except Exception:
logger.debug("Failed to get license status from cache, calling enterprise API")
cached = cls._read_cached_license_status()
if cached is not None:
return cached
return cls._fetch_and_cache_license_status()
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
"""Read license status from Redis cache, returning None on miss or failure."""
from services.feature_service import LicenseStatus
try:
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if raw:
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
return LicenseStatus(value)
except Exception:
logger.warning("Failed to read license status from cache", exc_info=True)
return None
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
"""Fetch license status from enterprise API and cache the result."""
from services.feature_service import LicenseStatus
# Cache miss or failure — call enterprise API
try:
info = cls.get_info()
license_info = info.get("License")
if license_info:
from services.feature_service import LicenseStatus
if not license_info:
return None
status = license_info.get("status", LicenseStatus.INACTIVE)
# Only cache valid statuses so license updates are picked up immediately
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING):
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, LICENSE_STATUS_CACHE_TTL, status)
except Exception:
logger.debug("Failed to cache license status")
return status
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
ttl = (
VALID_LICENSE_CACHE_TTL
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
else INVALID_LICENSE_CACHE_TTL
)
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
except Exception:
logger.warning("Failed to cache license status", exc_info=True)
return status
except Exception:
logger.exception("Failed to get enterprise license status")
return None

View File

@ -379,17 +379,19 @@ class FeatureService:
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
# License status and expiry are always exposed so the login page can
# show the expiry UI after a force-logout (the user is unauthenticated
# at that point). Workspace usage details remain auth-gated.
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
# so the login page can detect an expired/inactive license after force-logout.
# All other license details (expiry date, workspace usage) remain auth-gated.
# See CVE-2025-63387 for prior information-leakage context.
if license_info := enterprise_info.get("License"):
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
features.license.expired_at = license_info.get("expiredAt", "")
if is_authenticated and (workspaces_info := license_info.get("workspaces")):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if is_authenticated:
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]

View File

@ -1,9 +1,8 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
Covers:
- Default workspace auto-join behavior
- License status caching (get_cached_license_status)
"""
from unittest.mock import patch
@ -11,6 +10,9 @@ from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
INVALID_LICENSE_CACHE_TTL,
LICENSE_STATUS_CACHE_KEY,
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
"/default-workspace/members",
json={"account_id": account_id},
timeout=1.0,
raise_for_status=True,
)
def test_join_default_workspace_invalid_response_format_raises(self):
@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")
# ---------------------------------------------------------------------------
# get_cached_license_status
# ---------------------------------------------------------------------------
_EE_SVC = "services.enterprise.enterprise_service"
class TestGetCachedLicenseStatus:
"""Tests for EnterpriseService.get_cached_license_status."""
def test_returns_none_when_enterprise_disabled(self):
with patch(f"{_EE_SVC}.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False
assert EnterpriseService.get_cached_license_status() is None
def test_cache_hit_returns_license_status_enum(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = b"active"
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
assert isinstance(result, LicenseStatus)
mock_get_info.assert_not_called()
def test_cache_miss_fetches_api_and_caches_valid_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
)
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "expired"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRED
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
)
def test_redis_read_failure_falls_through_to_api(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_get_info.assert_called_once()
def test_redis_write_failure_still_returns_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_redis.setex.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "expiring"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRING
def test_api_failure_returns_none(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.side_effect = Exception("network failure")
assert EnterpriseService.get_cached_license_status() is None
def test_api_returns_no_license_info(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {} # no "License" key
assert EnterpriseService.get_cached_license_status() is None
mock_redis.setex.assert_not_called()