fix: enterprise API error handling and license enforcement (#33044)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Xiyuan Chen
2026-03-15 20:59:41 -07:00
committed by GitHub
parent dd39fcd9bc
commit 977ed79ea0
10 changed files with 383 additions and 23 deletions

View File

@ -1,9 +1,8 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
Covers:
- Default workspace auto-join behavior
- License status caching (get_cached_license_status)
"""
from unittest.mock import patch
@ -11,6 +10,9 @@ from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
INVALID_LICENSE_CACHE_TTL,
LICENSE_STATUS_CACHE_KEY,
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
"/default-workspace/members",
json={"account_id": account_id},
timeout=1.0,
raise_for_status=True,
)
def test_join_default_workspace_invalid_response_format_raises(self):
@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")
# ---------------------------------------------------------------------------
# get_cached_license_status
# ---------------------------------------------------------------------------
_EE_SVC = "services.enterprise.enterprise_service"
class TestGetCachedLicenseStatus:
"""Tests for EnterpriseService.get_cached_license_status."""
def test_returns_none_when_enterprise_disabled(self):
with patch(f"{_EE_SVC}.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False
assert EnterpriseService.get_cached_license_status() is None
def test_cache_hit_returns_license_status_enum(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = b"active"
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
assert isinstance(result, LicenseStatus)
mock_get_info.assert_not_called()
def test_cache_miss_fetches_api_and_caches_valid_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
)
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "expired"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRED
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
)
def test_redis_read_failure_falls_through_to_api(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_get_info.assert_called_once()
def test_redis_write_failure_still_returns_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_redis.setex.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "expiring"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRING
def test_api_failure_returns_none(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.side_effect = Exception("network failure")
assert EnterpriseService.get_cached_license_status() is None
def test_api_returns_no_license_info(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {} # no "License" key
assert EnterpriseService.get_cached_license_status() is None
mock_redis.setex.assert_not_called()

View File

@ -34,7 +34,6 @@ class TestTryPreUninstallPlugin:
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
@ -62,7 +61,6 @@ class TestTryPreUninstallPlugin:
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
@ -87,7 +85,6 @@ class TestTryPreUninstallPlugin:
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()