mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
add paid credit
This commit is contained in:
@ -7,6 +7,7 @@ from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,62 +24,27 @@ class CreditPoolService:
|
||||
return credit_pool
|
||||
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str) -> Optional[TenantCreditPool]:
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
|
||||
"""get tenant credit pool"""
|
||||
return (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_or_create_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""get or create credit pool"""
|
||||
# First try to get existing pool
|
||||
pool = cls.get_pool(tenant_id)
|
||||
if pool:
|
||||
return pool
|
||||
|
||||
# Create new pool if not exists, handle race condition
|
||||
try:
|
||||
# Double-check in case another thread created it
|
||||
pool = (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if pool:
|
||||
return pool
|
||||
|
||||
# Create new pool
|
||||
pool = TenantCreditPool(
|
||||
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
|
||||
)
|
||||
db.session.add(pool)
|
||||
db.session.commit()
|
||||
|
||||
except Exception:
|
||||
# If creation fails (e.g., due to race condition), rollback and try to get existing one
|
||||
db.session.rollback()
|
||||
pool = cls.get_pool(tenant_id)
|
||||
if not pool:
|
||||
raise
|
||||
|
||||
return pool
|
||||
|
||||
@classmethod
|
||||
def check_and_deduct_credits(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
):
|
||||
"""check and deduct credits"""
|
||||
logger.info("check and deduct credits")
|
||||
pool = cls.get_pool(tenant_id)
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
@ -86,24 +52,17 @@ class CreditPoolService:
|
||||
raise QuotaExceededError(
|
||||
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
update_values = {"quota_used": pool.quota_used + credits_required}
|
||||
|
||||
with db.session.begin():
|
||||
update_values = {"quota_used": pool.quota_used + credits_required}
|
||||
|
||||
where_conditions = [
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
|
||||
]
|
||||
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
|
||||
db.session.execute(stmt)
|
||||
|
||||
@classmethod
|
||||
def check_deduct_credits(cls, tenant_id: str, credits_required: int) -> bool:
|
||||
"""check and deduct credits"""
|
||||
pool = cls.get_pool(tenant_id)
|
||||
if not pool:
|
||||
return False
|
||||
|
||||
if pool.remaining_credits < credits_required:
|
||||
return False
|
||||
return True
|
||||
where_conditions = [
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
|
||||
]
|
||||
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
@ -49,8 +49,14 @@ class WorkspaceService:
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
pool = CreditPoolService.get_or_create_pool(tenant_id=tenant.id)
|
||||
tenant_info["trial_credits"] = pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = pool.quota_used
|
||||
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
|
||||
if paid_pool:
|
||||
tenant_info["trial_credits"] = paid_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = paid_pool.quota_used
|
||||
else:
|
||||
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
|
||||
if trial_pool:
|
||||
tenant_info["trial_credits"] = trial_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = trial_pool.quota_used
|
||||
|
||||
return tenant_info
|
||||
|
||||
Reference in New Issue
Block a user