From 75c3ef82d99b71ad30a5eed2fb6182185c229dc6 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:51:10 +0100 Subject: [PATCH] refactor: use EnumText for TenantCreditPool.pool_type (#33959) --- api/core/provider_manager.py | 4 ++-- api/models/model.py | 5 ++++- api/services/credit_pool_service.py | 6 +++++- .../services/test_credit_pool_service.py | 9 +++++---- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3c3fbd6dd2..6d2be0ab7a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -918,11 +918,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.TRIAL.value, + pool_type=ProviderQuotaType.TRIAL, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.PAID.value, + pool_type=ProviderQuotaType.PAID, ) else: trail_pool = None diff --git a/api/models/model.py b/api/models/model.py index 331a5b7d8c..4541a3b23a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -44,6 +44,7 @@ from .enums import ( MessageChainType, MessageFileBelongsTo, MessageStatus, + ProviderQuotaType, TagType, ) from .provider_ids import GenericProviderID @@ -2491,7 +2492,9 @@ class TenantCreditPool(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + pool_type: Mapped[ProviderQuotaType] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial" + ) quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1954602571..2894826935 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from models.enums import ProviderQuotaType logger = logging.getLogger(__name__) @@ -16,7 +17,10 @@ class CreditPoolService: def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" credit_pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + tenant_id=tenant_id, + quota_limit=dify_config.HOSTED_POOL_CREDITS, + quota_used=0, + pool_type=ProviderQuotaType.TRIAL, ) db.session.add(credit_pool) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 25de0588fa..0f63d98642 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -6,6 +6,7 @@ import pytest from core.errors.error import QuotaExceededError from models import TenantCreditPool +from models.enums import ProviderQuotaType from services.credit_pool_service import CreditPoolService @@ -20,7 +21,7 @@ class TestCreditPoolService: assert isinstance(pool, TenantCreditPool) assert pool.tenant_id == tenant_id - assert pool.pool_type == "trial" + assert pool.pool_type == ProviderQuotaType.TRIAL assert pool.quota_used == 0 assert pool.quota_limit > 0 @@ -28,14 +29,14 @@ class TestCreditPoolService: tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) - result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) assert result is not None assert result.tenant_id == tenant_id - assert result.pool_type == "trial" + assert result.pool_type == ProviderQuotaType.TRIAL def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): - result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None