refactor: use EnumText in provider models (#33634)

This commit is contained in:
tmimmanuel
2026-03-18 04:27:40 +00:00
committed by GitHub
parent 3454224ff9
commit 04c0bf61fa
5 changed files with 56 additions and 43 deletions

View File

@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
from models.enums import CredentialSourceType
from models.provider import ProviderType
from models.provider_ids import ModelProviderID
@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
id="lb-base",
name="LB Base",
credentials={},
credential_source_type="provider",
credential_source_type=CredentialSourceType.PROVIDER,
)
],
),
@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
id="lb-custom",
name="LB Custom",
credentials={},
credential_source_type="custom_model",
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
)
],
),
@ -826,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None:
configuration._update_load_balancing_configs_with_credential(
credential_id="cred-1",
credential_record=credential_record,
credential_source="provider",
credential_source=CredentialSourceType.PROVIDER,
session=session,
)
@ -844,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non
configuration._update_load_balancing_configs_with_credential(
credential_id="cred-1",
credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"),
credential_source="provider",
credential_source=CredentialSourceType.PROVIDER,
session=session,
)

View File

@ -19,6 +19,7 @@ from uuid import uuid4
import pytest
from models.enums import CredentialSourceType, PaymentStatus
from models.provider import (
LoadBalancingModelConfig,
Provider,
@ -158,7 +159,7 @@ class TestProviderModel:
# Assert
assert provider.tenant_id == tenant_id
assert provider.provider_name == provider_name
assert provider.provider_type == "custom"
assert provider.provider_type == ProviderType.CUSTOM
assert provider.is_valid is False
assert provider.quota_used == 0
@ -172,10 +173,10 @@ class TestProviderModel:
provider = Provider(
tenant_id=tenant_id,
provider_name="anthropic",
provider_type="system",
provider_type=ProviderType.SYSTEM,
is_valid=True,
credential_id=credential_id,
quota_type="paid",
quota_type=ProviderQuotaType.PAID,
quota_limit=10000,
quota_used=500,
)
@ -183,10 +184,10 @@ class TestProviderModel:
# Assert
assert provider.tenant_id == tenant_id
assert provider.provider_name == "anthropic"
assert provider.provider_type == "system"
assert provider.provider_type == ProviderType.SYSTEM
assert provider.is_valid is True
assert provider.credential_id == credential_id
assert provider.quota_type == "paid"
assert provider.quota_type == ProviderQuotaType.PAID
assert provider.quota_limit == 10000
assert provider.quota_used == 500
@ -199,7 +200,7 @@ class TestProviderModel:
)
# Assert
assert provider.provider_type == "custom"
assert provider.provider_type == ProviderType.CUSTOM
assert provider.is_valid is False
assert provider.quota_type == ""
assert provider.quota_limit is None
@ -213,7 +214,7 @@ class TestProviderModel:
provider = Provider(
tenant_id=tenant_id,
provider_name="openai",
provider_type="custom",
provider_type=ProviderType.CUSTOM,
)
# Act
@ -253,7 +254,7 @@ class TestProviderModel:
provider = Provider(
tenant_id=str(uuid4()),
provider_name="openai",
provider_type=ProviderType.SYSTEM.value,
provider_type=ProviderType.SYSTEM,
is_valid=True,
)
@ -266,13 +267,13 @@ class TestProviderModel:
provider = Provider(
tenant_id=str(uuid4()),
provider_name="openai",
quota_type="trial",
quota_type=ProviderQuotaType.TRIAL,
quota_limit=1000,
quota_used=250,
)
# Assert
assert provider.quota_type == "trial"
assert provider.quota_type == ProviderQuotaType.TRIAL
assert provider.quota_limit == 1000
assert provider.quota_used == 250
remaining = provider.quota_limit - provider.quota_used
@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider:
preferred = TenantPreferredModelProvider(
tenant_id=tenant_id,
provider_name="openai",
preferred_provider_type="custom",
preferred_provider_type=ProviderType.CUSTOM,
)
# Assert
assert preferred.tenant_id == tenant_id
assert preferred.provider_name == "openai"
assert preferred.preferred_provider_type == "custom"
assert preferred.preferred_provider_type == ProviderType.CUSTOM
def test_tenant_preferred_provider_system_type(self):
"""Test tenant preferred provider with system type."""
@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider:
preferred = TenantPreferredModelProvider(
tenant_id=str(uuid4()),
provider_name="anthropic",
preferred_provider_type="system",
preferred_provider_type=ProviderType.SYSTEM,
)
# Assert
assert preferred.preferred_provider_type == "system"
assert preferred.preferred_provider_type == ProviderType.SYSTEM
class TestProviderOrder:
@ -470,7 +471,7 @@ class TestProviderOrder:
quantity=1,
currency=None,
total_amount=None,
payment_status="wait_pay",
payment_status=PaymentStatus.WAIT_PAY,
paid_at=None,
pay_failed_at=None,
refunded_at=None,
@ -481,7 +482,7 @@ class TestProviderOrder:
assert order.provider_name == "openai"
assert order.account_id == account_id
assert order.payment_product_id == "prod_123"
assert order.payment_status == "wait_pay"
assert order.payment_status == PaymentStatus.WAIT_PAY
assert order.quantity == 1
def test_provider_order_with_payment_details(self):
@ -502,7 +503,7 @@ class TestProviderOrder:
quantity=5,
currency="USD",
total_amount=9999,
payment_status="paid",
payment_status=PaymentStatus.PAID,
paid_at=paid_time,
pay_failed_at=None,
refunded_at=None,
@ -514,7 +515,7 @@ class TestProviderOrder:
assert order.quantity == 5
assert order.currency == "USD"
assert order.total_amount == 9999
assert order.payment_status == "paid"
assert order.payment_status == PaymentStatus.PAID
assert order.paid_at == paid_time
def test_provider_order_payment_statuses(self):
@ -536,23 +537,23 @@ class TestProviderOrder:
}
# Act & Assert - Wait pay status
wait_order = ProviderOrder(**base_params, payment_status="wait_pay")
assert wait_order.payment_status == "wait_pay"
wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY)
assert wait_order.payment_status == PaymentStatus.WAIT_PAY
# Act & Assert - Paid status
paid_order = ProviderOrder(**base_params, payment_status="paid")
assert paid_order.payment_status == "paid"
paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID)
assert paid_order.payment_status == PaymentStatus.PAID
# Act & Assert - Failed status
failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)}
failed_order = ProviderOrder(**failed_params, payment_status="failed")
assert failed_order.payment_status == "failed"
failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED)
assert failed_order.payment_status == PaymentStatus.FAILED
assert failed_order.pay_failed_at is not None
# Act & Assert - Refunded status
refunded_params = {**base_params, "refunded_at": datetime.now(UTC)}
refunded_order = ProviderOrder(**refunded_params, payment_status="refunded")
assert refunded_order.payment_status == "refunded"
refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED)
assert refunded_order.payment_status == PaymentStatus.REFUNDED
assert refunded_order.refunded_at is not None
@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig:
name="Secondary API Key",
encrypted_config='{"api_key": "encrypted_value"}',
credential_id=credential_id,
credential_source_type="custom",
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
)
# Assert
assert config.encrypted_config == '{"api_key": "encrypted_value"}'
assert config.credential_id == credential_id
assert config.credential_source_type == "custom"
assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL
def test_load_balancing_config_disabled(self):
"""Test disabled load balancing config."""