mirror of
https://github.com/langgenius/dify.git
synced 2026-05-18 07:56:36 +08:00
Compare commits
36 Commits
hotfix/1.1
...
test/workf
| Author | SHA1 | Date | |
|---|---|---|---|
| f7ae14d50e | |||
| 8e093a2b25 | |||
| 95bdeb674f | |||
| b8b0422e73 | |||
| d5870d2620 | |||
| 3200c574a6 | |||
| ba0c911011 | |||
| 77e7f0a7de | |||
| 74e5ac4153 | |||
| c494f80452 | |||
| fb4d51e750 | |||
| a119726469 | |||
| 8876f69c24 | |||
| 1f8bf024e7 | |||
| 86cda295f0 | |||
| a1e6c3ee77 | |||
| b45f056f62 | |||
| b88e4a5e9c | |||
| 42a644cedb | |||
| 0c01931bcc | |||
| 412eb15527 | |||
| d2deab61d1 | |||
| 430fb1ee97 | |||
| 2b0b3e3321 | |||
| 9f0fcdd049 | |||
| 727fc057d3 | |||
| 3d10cf97f1 | |||
| 373f8245af | |||
| 168ba4caa3 | |||
| 688ccb5aa9 | |||
| af143312f2 | |||
| 3c58c68b8d | |||
| 20f901223b | |||
| 16e8bf1cf9 | |||
| 1943785c1c | |||
| 6633f5aef8 |
@ -4,6 +4,7 @@ import urllib.parse
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -179,7 +180,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@ -607,15 +607,19 @@ class PublishedRagPipelineApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=db.session, # type: ignore[reportArgumentType,arg-type]
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
with Session(db.engine) as session:
|
||||
pipeline = session.merge(pipeline)
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
session.add(pipeline)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
|
||||
@ -16,14 +16,12 @@ api = ExternalApi(
|
||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||
|
||||
from . import mail as _mail
|
||||
from .app import dsl as _app_dsl
|
||||
from .plugin import plugin as _plugin
|
||||
from .workspace import workspace as _workspace
|
||||
|
||||
api.add_namespace(inner_api_ns)
|
||||
|
||||
__all__ = [
|
||||
"_app_dsl",
|
||||
"_mail",
|
||||
"_plugin",
|
||||
"_workspace",
|
||||
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,110 +0,0 @@
|
||||
"""Inner API endpoints for app DSL import/export.
|
||||
|
||||
Called by the enterprise admin-api service. Import requires ``creator_email``
|
||||
to attribute the created app; workspace/membership validation is done by the
|
||||
Go admin-api caller.
|
||||
"""
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App
|
||||
from models.account import AccountStatus
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
|
||||
|
||||
class InnerAppDSLImportPayload(BaseModel):
|
||||
yaml_content: str = Field(description="YAML DSL content")
|
||||
creator_email: str = Field(description="Email of the workspace member who will own the imported app")
|
||||
name: str | None = Field(default=None, description="Override app name from DSL")
|
||||
description: str | None = Field(default=None, description="Override app description from DSL")
|
||||
|
||||
|
||||
register_schema_model(inner_api_ns, InnerAppDSLImportPayload)
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/workspaces/<string:workspace_id>/dsl/import")
|
||||
class EnterpriseAppDSLImport(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc("enterprise_app_dsl_import")
|
||||
@inner_api_ns.expect(inner_api_ns.models[InnerAppDSLImportPayload.__name__])
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Import completed",
|
||||
202: "Import pending (DSL version mismatch requires confirmation)",
|
||||
400: "Import failed (business error)",
|
||||
404: "Creator account not found or inactive",
|
||||
}
|
||||
)
|
||||
def post(self, workspace_id: str):
|
||||
"""Import a DSL into a workspace on behalf of a specified creator."""
|
||||
args = InnerAppDSLImportPayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
account = _get_active_account(args.creator_email)
|
||||
if account is None:
|
||||
return {"message": f"account '{args.creator_email}' not found or inactive"}, 404
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=args.yaml_content,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
if result.status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/apps/<string:app_id>/dsl")
|
||||
class EnterpriseAppDSLExport(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc(
|
||||
"enterprise_app_dsl_export",
|
||||
responses={
|
||||
200: "Export successful",
|
||||
404: "App not found",
|
||||
},
|
||||
)
|
||||
def get(self, app_id: str):
|
||||
"""Export an app's DSL as YAML."""
|
||||
include_secret = request.args.get("include_secret", "false").lower() == "true"
|
||||
|
||||
app_model = db.session.query(App).filter_by(id=app_id).first()
|
||||
if not app_model:
|
||||
return {"message": "app not found"}, 404
|
||||
|
||||
data = AppDslService.export_dsl(
|
||||
app_model=app_model,
|
||||
include_secret=include_secret,
|
||||
)
|
||||
|
||||
return {"data": data}, 200
|
||||
|
||||
|
||||
def _get_active_account(email: str) -> Account | None:
|
||||
"""Look up an active account by email.
|
||||
|
||||
Workspace membership is already validated by the Go admin-api caller.
|
||||
"""
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
if account is None or account.status != AccountStatus.ACTIVE:
|
||||
return None
|
||||
return account
|
||||
@ -1,5 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
@ -8,8 +7,6 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import httpx
|
||||
import qdrant_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from flask import current_app
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel
|
||||
@ -291,27 +288,26 @@ class TidbOnQdrantVector(BaseVector):
|
||||
if not ids:
|
||||
return
|
||||
|
||||
batch_size = 1000
|
||||
for i in range(0, len(ids), batch_size):
|
||||
batch = ids[i : i + batch_size]
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=batch),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code != 404:
|
||||
raise e
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
@ -420,16 +416,13 @@ class TidbOnQdrantVector(BaseVector):
|
||||
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id)
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if not tidb_auth_binding:
|
||||
logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id)
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if tidb_auth_binding:
|
||||
logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id)
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
@ -440,18 +433,11 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
logger.info(
|
||||
"Assigning idle cluster %s to tenant %s",
|
||||
idle_tidb_auth_binding.cluster_id,
|
||||
dataset.tenant_id,
|
||||
)
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
tidb_auth_binding = idle_tidb_auth_binding
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id)
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
dify_config.TIDB_API_URL or "",
|
||||
@ -460,39 +446,21 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
dify_config.TIDB_PRIVATE_KEY or "",
|
||||
dify_config.TIDB_REGION or "",
|
||||
)
|
||||
logger.info(
|
||||
"New cluster created: cluster_id=%s, qdrant_endpoint=%s",
|
||||
new_cluster["cluster_id"],
|
||||
new_cluster.get("qdrant_endpoint"),
|
||||
)
|
||||
new_tidb_auth_binding = TidbAuthBinding(
|
||||
cluster_id=new_cluster["cluster_id"],
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
tidb_auth_binding = new_tidb_auth_binding
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
else:
|
||||
logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id)
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
qdrant_url = (
|
||||
(tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or ""
|
||||
)
|
||||
logger.info(
|
||||
"Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)",
|
||||
qdrant_url,
|
||||
tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None,
|
||||
dify_config.TIDB_ON_QDRANT_URL,
|
||||
)
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
@ -507,7 +475,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=TidbOnQdrantConfig(
|
||||
endpoint=qdrant_url,
|
||||
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
|
||||
api_key=TIDB_ON_QDRANT_API_KEY,
|
||||
root_path=str(config.root_path),
|
||||
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
@ -12,50 +11,8 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TidbService:
|
||||
@staticmethod
|
||||
def extract_qdrant_endpoint(cluster_response: dict) -> str | None:
|
||||
"""Extract the qdrant endpoint URL from a Get Cluster API response.
|
||||
|
||||
Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``),
|
||||
prepends ``qdrant-`` and wraps it as an ``https://`` URL.
|
||||
"""
|
||||
endpoints = cluster_response.get("endpoints") or {}
|
||||
public = endpoints.get("public") or {}
|
||||
host = public.get("host")
|
||||
if host:
|
||||
return f"https://qdrant-{host}"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
|
||||
"""Call Get Cluster API and extract the qdrant endpoint.
|
||||
|
||||
Use ``extract_qdrant_endpoint`` instead when you already have
|
||||
the cluster response to avoid a redundant API call.
|
||||
"""
|
||||
try:
|
||||
logger.info("Fetching qdrant endpoint for cluster %s", cluster_id)
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if not cluster_response:
|
||||
logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id)
|
||||
return None
|
||||
qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response)
|
||||
if qdrant_url:
|
||||
logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url)
|
||||
return qdrant_url
|
||||
logger.warning(
|
||||
"No endpoints.public.host found for cluster %s, response keys: %s",
|
||||
cluster_id,
|
||||
list(cluster_response.keys()),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def create_tidb_serverless_cluster(
|
||||
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
|
||||
@ -93,45 +50,26 @@ class TidbService:
|
||||
"rootPassword": password,
|
||||
}
|
||||
|
||||
logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region)
|
||||
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_id = response_data["clusterId"]
|
||||
logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id)
|
||||
retry_count = 0
|
||||
max_retries = 30
|
||||
while retry_count < max_retries:
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if cluster_response["state"] == "ACTIVE":
|
||||
user_prefix = cluster_response["userPrefix"]
|
||||
qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response)
|
||||
logger.info(
|
||||
"Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s",
|
||||
cluster_id,
|
||||
user_prefix,
|
||||
qdrant_endpoint,
|
||||
)
|
||||
return {
|
||||
"cluster_id": cluster_id,
|
||||
"cluster_name": display_name,
|
||||
"account": f"{user_prefix}.root",
|
||||
"password": password,
|
||||
"qdrant_endpoint": qdrant_endpoint,
|
||||
}
|
||||
logger.info(
|
||||
"Cluster %s state=%s, retry %d/%d",
|
||||
cluster_id,
|
||||
cluster_response["state"],
|
||||
retry_count + 1,
|
||||
max_retries,
|
||||
)
|
||||
time.sleep(30)
|
||||
time.sleep(30) # wait 30 seconds before retrying
|
||||
retry_count += 1
|
||||
logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries)
|
||||
else:
|
||||
logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
@ -233,20 +171,8 @@ class TidbService:
|
||||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
if not cluster_info.qdrant_endpoint:
|
||||
cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint(
|
||||
item
|
||||
) or TidbService.fetch_qdrant_endpoint(
|
||||
api_url, public_key, private_key, item["clusterId"]
|
||||
)
|
||||
if cluster_info.qdrant_endpoint:
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
else:
|
||||
logger.warning(
|
||||
"Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later",
|
||||
item["clusterId"],
|
||||
)
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
else:
|
||||
@ -304,29 +230,19 @@ class TidbService:
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
logger.info("Batch created %d clusters", len(response_data.get("clusters", [])))
|
||||
for item in response_data["clusters"]:
|
||||
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
|
||||
cached_password = redis_client.get(cache_key)
|
||||
if not cached_password:
|
||||
logger.warning("No cached password for cluster %s, skipping", item["displayName"])
|
||||
continue
|
||||
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
|
||||
logger.info(
|
||||
"Batch cluster %s: qdrant_endpoint=%s",
|
||||
item["clusterId"],
|
||||
qdrant_endpoint,
|
||||
)
|
||||
cluster_info = {
|
||||
"cluster_id": item["clusterId"],
|
||||
"cluster_name": item["displayName"],
|
||||
"account": "root",
|
||||
"password": cached_password.decode("utf-8"),
|
||||
"qdrant_endpoint": qdrant_endpoint,
|
||||
}
|
||||
cluster_infos.append(cluster_info)
|
||||
return cluster_infos
|
||||
else:
|
||||
logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text)
|
||||
response.raise_for_status()
|
||||
return []
|
||||
|
||||
@ -1,17 +1,56 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum, auto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCharge:
|
||||
"""
|
||||
Result of a quota consumption operation.
|
||||
|
||||
Attributes:
|
||||
success: Whether the quota charge succeeded
|
||||
charge_id: UUID for refund, or None if failed/disabled
|
||||
"""
|
||||
|
||||
success: bool
|
||||
charge_id: str | None
|
||||
_quota_type: "QuotaType"
|
||||
|
||||
def refund(self) -> None:
|
||||
"""
|
||||
Refund this quota charge.
|
||||
|
||||
Safe to call even if charge failed or was disabled.
|
||||
This method guarantees no exceptions will be raised.
|
||||
"""
|
||||
if self.charge_id:
|
||||
self._quota_type.refund(self.charge_id)
|
||||
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
|
||||
|
||||
|
||||
class QuotaType(StrEnum):
|
||||
"""
|
||||
Supported quota types for tenant feature usage.
|
||||
|
||||
Add additional types here whenever new billable features become available.
|
||||
"""
|
||||
|
||||
# Trigger execution quota
|
||||
TRIGGER = auto()
|
||||
|
||||
# Workflow execution quota
|
||||
WORKFLOW = auto()
|
||||
|
||||
UNLIMITED = auto()
|
||||
|
||||
@property
|
||||
def billing_key(self) -> str:
|
||||
"""
|
||||
Get the billing key for the feature.
|
||||
"""
|
||||
match self:
|
||||
case QuotaType.TRIGGER:
|
||||
return "trigger_event"
|
||||
@ -19,3 +58,152 @@ class QuotaType(StrEnum):
|
||||
return "api_rate_limit"
|
||||
case _:
|
||||
raise ValueError(f"Invalid quota type: {self}")
|
||||
|
||||
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Consume quota for the feature.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to consume (default: 1)
|
||||
|
||||
Returns:
|
||||
QuotaCharge with success status and charge_id for refund
|
||||
|
||||
Raises:
|
||||
QuotaExceededError: When quota is insufficient
|
||||
"""
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.debug("Billing disabled, allowing request for %s", tenant_id)
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
|
||||
|
||||
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to consume must be greater than 0")
|
||||
|
||||
try:
|
||||
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
|
||||
|
||||
if response.get("result") != "success":
|
||||
logger.warning(
|
||||
"Failed to consume quota for %s, feature %s details: %s",
|
||||
tenant_id,
|
||||
self.value,
|
||||
response.get("detail"),
|
||||
)
|
||||
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
|
||||
|
||||
charge_id = response.get("history_id")
|
||||
logger.debug(
|
||||
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
|
||||
amount,
|
||||
self.value,
|
||||
tenant_id,
|
||||
charge_id,
|
||||
)
|
||||
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
# fail-safe: allow request on billing errors
|
||||
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
|
||||
return unlimited()
|
||||
|
||||
def check(self, tenant_id: str, amount: int = 1) -> bool:
|
||||
"""
|
||||
Check if tenant has sufficient quota without consuming.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to check (default: 1)
|
||||
|
||||
Returns:
|
||||
True if quota is sufficient, False otherwise
|
||||
"""
|
||||
from configs import dify_config
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return True
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to check must be greater than 0")
|
||||
|
||||
try:
|
||||
remaining = self.get_remaining(tenant_id)
|
||||
return remaining >= amount if remaining != -1 else True
|
||||
except Exception:
|
||||
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
|
||||
# fail-safe: allow request on billing errors
|
||||
return True
|
||||
|
||||
def refund(self, charge_id: str) -> None:
|
||||
"""
|
||||
Refund quota using charge_id from consume().
|
||||
|
||||
This method guarantees no exceptions will be raised.
|
||||
All errors are logged but silently handled.
|
||||
|
||||
Args:
|
||||
charge_id: The UUID returned from consume()
|
||||
"""
|
||||
try:
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return
|
||||
|
||||
if not charge_id:
|
||||
logger.warning("Cannot refund: charge_id is empty")
|
||||
return
|
||||
|
||||
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
|
||||
|
||||
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
|
||||
if response.get("result") == "success":
|
||||
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
|
||||
else:
|
||||
logger.warning("Refund failed for charge_id: %s", charge_id)
|
||||
|
||||
except Exception:
|
||||
# Catch ALL exceptions - refund must never fail
|
||||
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
|
||||
# Don't raise - refund is best-effort and must be silent
|
||||
|
||||
def get_remaining(self, tenant_id: str) -> int:
|
||||
"""
|
||||
Get remaining quota for the tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
|
||||
Returns:
|
||||
Remaining quota amount
|
||||
"""
|
||||
from services.billing_service import BillingService
|
||||
|
||||
try:
|
||||
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
|
||||
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
|
||||
if isinstance(usage_info, dict):
|
||||
return usage_info.get("remaining", 0)
|
||||
# If it returns a simple number, treat it as remaining
|
||||
return int(usage_info) if usage_info else 0
|
||||
except Exception:
|
||||
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
|
||||
return -1
|
||||
|
||||
|
||||
def unlimited() -> QuotaCharge:
|
||||
"""
|
||||
Return a quota charge for unlimited quota.
|
||||
|
||||
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
|
||||
"""
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
|
||||
|
||||
@ -1,26 +0,0 @@
|
||||
"""add qdrant_endpoint to tidb_auth_bindings
|
||||
|
||||
Revision ID: 8574b23a38fd
|
||||
Revises: 6b5f9f8b1a2c
|
||||
Create Date: 2026-04-14 15:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8574b23a38fd"
|
||||
down_revision = "6b5f9f8b1a2c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
|
||||
batch_op.drop_column("qdrant_endpoint")
|
||||
@ -1250,7 +1250,6 @@ class TidbAuthBinding(TypeBase):
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@ -113,7 +113,6 @@ class DataSourceType(StrEnum):
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
LOCAL_FILE = "local_file"
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
ONLINE_DRIVE = "online_drive"
|
||||
|
||||
|
||||
class ProcessRuleMode(StrEnum):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.13.3"
|
||||
version = "1.13.2"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
||||
@ -57,7 +57,6 @@ def create_clusters(batch_size):
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
|
||||
active=False,
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
@ -18,13 +18,12 @@ from core.app.features.rate_limiting import RateLimit
|
||||
from core.app.features.rate_limiting.rate_limit import rate_limit_context
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.db import session_factory
|
||||
from enums.quota_type import QuotaType
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from extensions.otel import AppGenerateHandler, trace_span
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
|
||||
|
||||
@ -89,7 +88,7 @@ class AppGenerateService:
|
||||
def generate(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Account | EndUser,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
@ -107,7 +106,7 @@ class AppGenerateService:
|
||||
quota_charge = unlimited()
|
||||
if dify_config.BILLING_ENABLED:
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
|
||||
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
|
||||
except QuotaExceededError:
|
||||
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
|
||||
|
||||
@ -117,150 +116,139 @@ class AppGenerateService:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
quota_charge.commit()
|
||||
effective_mode = (
|
||||
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
|
||||
)
|
||||
match effective_mode:
|
||||
case AppMode.COMPLETION:
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
CompletionAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
if app_model.mode == AppMode.COMPLETION:
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
CompletionAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
case AppMode.AGENT_CHAT:
|
||||
return rate_limit.generate(
|
||||
AgentChatAppGenerator.convert_to_event_stream(
|
||||
AgentChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
return rate_limit.generate(
|
||||
AgentChatAppGenerator.convert_to_event_stream(
|
||||
AgentChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
case AppMode.CHAT:
|
||||
return rate_limit.generate(
|
||||
ChatAppGenerator.convert_to_event_stream(
|
||||
ChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT:
|
||||
return rate_limit.generate(
|
||||
ChatAppGenerator.convert_to_event_stream(
|
||||
ChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
|
||||
if streaming:
|
||||
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
if streaming:
|
||||
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
generator.convert_to_event_stream(
|
||||
generator.retrieve_events(
|
||||
AppMode.ADVANCED_CHAT,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
# Blocking mode: run synchronously and return JSON instead of SSE
|
||||
# Keep behaviour consistent with WORKFLOW blocking branch.
|
||||
advanced_generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
advanced_generator.convert_to_event_stream(
|
||||
advanced_generator.generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=False,
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
generator.convert_to_event_stream(
|
||||
generator.retrieve_events(
|
||||
AppMode.ADVANCED_CHAT,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
# Blocking mode: run synchronously and return JSON instead of SSE
|
||||
# Keep behaviour consistent with WORKFLOW blocking branch.
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory.get_session_maker(),
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
advanced_generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
advanced_generator.convert_to_event_stream(
|
||||
advanced_generator.generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=False,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
case AppMode.WORKFLOW:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
if streaming:
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
root_node_id=root_node_id,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
MessageBasedAppGenerator.retrieve_events(
|
||||
AppMode.WORKFLOW,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory.get_session_maker(),
|
||||
state_owner_user_id=workflow.created_by,
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
if streaming:
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
root_node_id=root_node_id,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=False,
|
||||
root_node_id=root_node_id,
|
||||
call_depth=0,
|
||||
pause_state_config=pause_config,
|
||||
MessageBasedAppGenerator.retrieve_events(
|
||||
AppMode.WORKFLOW,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory.get_session_maker(),
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=False,
|
||||
root_node_id=root_node_id,
|
||||
call_depth=0,
|
||||
pause_state_config=pause_config,
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
rate_limit.exit(request_id)
|
||||
@ -292,83 +280,53 @@ class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
match app_model.mode:
|
||||
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
)
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
case AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
case _:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(
|
||||
cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
|
||||
):
|
||||
match app_model.mode:
|
||||
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_loop_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
)
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_loop_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
case AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
case _:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Account | EndUser,
|
||||
user: Union[Account, EndUser],
|
||||
message_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Mapping | Generator:
|
||||
) -> Union[Mapping, Generator]:
|
||||
"""
|
||||
Generate more like this
|
||||
:param app_model: app model
|
||||
|
||||
@ -22,7 +22,6 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
|
||||
from models.workflow import Workflow
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
|
||||
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -89,10 +88,7 @@ class AsyncWorkflowService:
|
||||
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
|
||||
|
||||
# 2. Get workflow
|
||||
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
|
||||
|
||||
# commit read only session before starting the billig rpc call
|
||||
session.commit()
|
||||
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
|
||||
|
||||
# 3. Get dispatcher based on tenant subscription
|
||||
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
|
||||
@ -135,10 +131,9 @@ class AsyncWorkflowService:
|
||||
trigger_log = trigger_log_repo.create(trigger_log)
|
||||
session.commit()
|
||||
|
||||
# 7. Reserve quota (commit after successful dispatch)
|
||||
quota_charge = unlimited()
|
||||
# 7. Check and consume quota
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
|
||||
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
|
||||
except QuotaExceededError as e:
|
||||
# Update trigger log status
|
||||
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
|
||||
@ -158,18 +153,13 @@ class AsyncWorkflowService:
|
||||
# 9. Dispatch to appropriate queue
|
||||
task_data_dict = task_data.model_dump(mode="json")
|
||||
|
||||
try:
|
||||
task: AsyncResult[Any] | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict)
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict)
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict)
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
task: AsyncResult[Any] | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict)
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict)
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict)
|
||||
|
||||
# 10. Update trigger log with task info
|
||||
trigger_log.status = WorkflowTriggerStatus.QUEUED
|
||||
@ -305,21 +295,13 @@ class AsyncWorkflowService:
|
||||
return [log.to_dict() for log in logs]
|
||||
|
||||
@staticmethod
|
||||
def _get_workflow(
|
||||
workflow_service: WorkflowService,
|
||||
app_model: App,
|
||||
workflow_id: str | None = None,
|
||||
session: Session | None = None,
|
||||
) -> Workflow:
|
||||
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
|
||||
"""
|
||||
Get workflow for the app
|
||||
|
||||
Args:
|
||||
app_model: App model instance
|
||||
workflow_id: Optional specific workflow ID
|
||||
session: Reuse this SQLAlchemy session for the lookup when provided,
|
||||
so the caller's explicit session bears the connection cost
|
||||
instead of Flask's request-scoped ``db.session``.
|
||||
|
||||
Returns:
|
||||
Workflow instance
|
||||
@ -329,12 +311,12 @@ class AsyncWorkflowService:
|
||||
"""
|
||||
if workflow_id:
|
||||
# Get specific published workflow
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
|
||||
if not workflow:
|
||||
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
|
||||
else:
|
||||
# Get default published workflow
|
||||
workflow = workflow_service.get_published_workflow(app_model, session=session)
|
||||
workflow = workflow_service.get_published_workflow(app_model)
|
||||
if not workflow:
|
||||
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")
|
||||
|
||||
|
||||
@ -2,11 +2,12 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||
from typing_extensions import TypedDict
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
@ -25,147 +26,6 @@ class SubscriptionPlan(TypedDict):
|
||||
expiration_date: int
|
||||
|
||||
|
||||
class QuotaReserveResult(TypedDict):
|
||||
reservation_id: str
|
||||
available: int
|
||||
reserved: int
|
||||
|
||||
|
||||
class QuotaCommitResult(TypedDict):
|
||||
available: int
|
||||
reserved: int
|
||||
refunded: int
|
||||
|
||||
|
||||
class QuotaReleaseResult(TypedDict):
|
||||
available: int
|
||||
reserved: int
|
||||
released: int
|
||||
|
||||
|
||||
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
|
||||
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
|
||||
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
|
||||
|
||||
|
||||
class _TenantFeatureQuota(TypedDict):
|
||||
usage: int
|
||||
limit: int
|
||||
reset_date: NotRequired[int]
|
||||
|
||||
|
||||
class TenantFeatureQuotaInfo(TypedDict):
|
||||
"""Response of /quota/info.
|
||||
|
||||
NOTE (hj24):
|
||||
- Same convention as BillingInfo: billing may return int fields as str,
|
||||
always keep non-strict mode to auto-coerce.
|
||||
"""
|
||||
|
||||
trigger_event: _TenantFeatureQuota
|
||||
api_rate_limit: _TenantFeatureQuota
|
||||
|
||||
|
||||
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
|
||||
|
||||
|
||||
class _BillingQuota(TypedDict):
|
||||
size: int
|
||||
limit: int
|
||||
|
||||
|
||||
class _VectorSpaceQuota(TypedDict):
|
||||
size: float
|
||||
limit: int
|
||||
|
||||
|
||||
class _KnowledgeRateLimit(TypedDict):
|
||||
# NOTE (hj24):
|
||||
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
|
||||
# 2. Keep it for compatibility for now, can be deprecated in future versions.
|
||||
size: NotRequired[int]
|
||||
# NOTE END
|
||||
limit: int
|
||||
|
||||
|
||||
class _BillingSubscription(TypedDict):
|
||||
plan: str
|
||||
interval: str
|
||||
education: bool
|
||||
|
||||
|
||||
class BillingInfo(TypedDict):
|
||||
"""Response of /subscription/info.
|
||||
|
||||
NOTE (hj24):
|
||||
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
|
||||
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
|
||||
1. validate_python in non-strict mode will coerce it to the expected type
|
||||
2. In strict mode, it will raise ValidationError
|
||||
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
subscription: _BillingSubscription
|
||||
members: _BillingQuota
|
||||
apps: _BillingQuota
|
||||
vector_space: _VectorSpaceQuota
|
||||
knowledge_rate_limit: _KnowledgeRateLimit
|
||||
documents_upload_quota: _BillingQuota
|
||||
annotation_quota_limit: _BillingQuota
|
||||
docs_processing: str
|
||||
can_replace_logo: bool
|
||||
model_load_balancing_enabled: bool
|
||||
knowledge_pipeline_publish_enabled: bool
|
||||
next_credit_reset_date: NotRequired[int]
|
||||
|
||||
|
||||
_billing_info_adapter = TypeAdapter(BillingInfo)
|
||||
|
||||
|
||||
class KnowledgeRateLimitDict(TypedDict):
|
||||
limit: int
|
||||
subscription_plan: str
|
||||
|
||||
|
||||
class TenantFeaturePlanUsageDict(TypedDict):
|
||||
result: str
|
||||
history_id: str
|
||||
|
||||
|
||||
class LangContentDict(TypedDict):
|
||||
lang: str
|
||||
title: str
|
||||
subtitle: str
|
||||
body: str
|
||||
title_pic_url: str
|
||||
|
||||
|
||||
class NotificationDict(TypedDict):
|
||||
notification_id: str
|
||||
contents: dict[str, LangContentDict]
|
||||
frequency: Literal["once", "every_page_load"]
|
||||
|
||||
|
||||
class AccountNotificationDict(TypedDict, total=False):
|
||||
should_show: bool
|
||||
notification: NotificationDict
|
||||
shouldShow: bool
|
||||
notifications: list[dict]
|
||||
|
||||
|
||||
class UpsertNotificationDict(TypedDict):
|
||||
notification_id: str
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsDict(TypedDict):
|
||||
count: int
|
||||
|
||||
|
||||
class DismissNotificationDict(TypedDict):
|
||||
success: bool
|
||||
|
||||
|
||||
class BillingService:
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
@ -178,73 +38,21 @@ class BillingService:
|
||||
_PLAN_CACHE_TTL = 600
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str) -> BillingInfo:
|
||||
def get_info(cls, tenant_id: str):
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||
return _billing_info_adapter.validate_python(billing_info)
|
||||
return billing_info
|
||||
|
||||
@classmethod
|
||||
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
|
||||
"""Deprecated: Use get_quota_info instead."""
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
|
||||
return usage_info
|
||||
|
||||
@classmethod
|
||||
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
|
||||
params = {"tenant_id": tenant_id}
|
||||
return _tenant_feature_quota_info_adapter.validate_python(
|
||||
cls._send_request("GET", "/quota/info", params=params)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def quota_reserve(
|
||||
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
|
||||
) -> QuotaReserveResult:
|
||||
"""Reserve quota before task execution."""
|
||||
payload: dict = {
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"request_id": request_id,
|
||||
"amount": amount,
|
||||
}
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
|
||||
|
||||
@classmethod
|
||||
def quota_commit(
|
||||
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
|
||||
) -> QuotaCommitResult:
|
||||
"""Commit a reservation with actual consumption."""
|
||||
payload: dict = {
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"reservation_id": reservation_id,
|
||||
"actual_amount": actual_amount,
|
||||
}
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
|
||||
|
||||
@classmethod
|
||||
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
|
||||
"""Release a reservation (cancel, return frozen quota)."""
|
||||
return _quota_release_adapter.validate_python(
|
||||
cls._send_request(
|
||||
"POST",
|
||||
"/quota/release",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"feature_key": feature_key,
|
||||
"reservation_id": reservation_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str):
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
@ -29,15 +29,14 @@ class CreditPoolService:
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
|
||||
"""get tenant credit pool"""
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
return session.scalar(
|
||||
select(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.limit(1)
|
||||
return (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_credits_available(
|
||||
|
||||
@ -281,7 +281,7 @@ class FeatureService:
|
||||
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
|
||||
features_usage_info = BillingService.get_quota_info(tenant_id)
|
||||
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
|
||||
|
||||
features.billing.enabled = billing_info["enabled"]
|
||||
features.billing.subscription.plan = billing_info["subscription"]["plan"]
|
||||
@ -312,10 +312,7 @@ class FeatureService:
|
||||
features.apps.limit = billing_info["apps"]["limit"]
|
||||
|
||||
if "vector_space" in billing_info:
|
||||
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
|
||||
# but LimitationModel.size is int; truncate here for compatibility
|
||||
features.vector_space.size = int(billing_info["vector_space"]["size"])
|
||||
# NOTE END
|
||||
features.vector_space.size = billing_info["vector_space"]["size"]
|
||||
features.vector_space.limit = billing_info["vector_space"]["limit"]
|
||||
|
||||
if "documents_upload_quota" in billing_info:
|
||||
@ -336,11 +333,7 @@ class FeatureService:
|
||||
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
|
||||
|
||||
if "knowledge_rate_limit" in billing_info:
|
||||
# NOTE (hj24):
|
||||
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
|
||||
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
|
||||
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
|
||||
# NOTE END
|
||||
|
||||
if "knowledge_pipeline_publish_enabled" in billing_info:
|
||||
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
|
||||
|
||||
@ -1,233 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enums.quota_type import QuotaType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCharge:
|
||||
"""
|
||||
Result of a quota reservation (Reserve phase).
|
||||
|
||||
Lifecycle:
|
||||
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
|
||||
try:
|
||||
do_work()
|
||||
charge.commit() # Confirm consumption
|
||||
except:
|
||||
charge.refund() # Release frozen quota
|
||||
|
||||
If neither commit() nor refund() is called, the billing system's
|
||||
cleanup CronJob will auto-release the reservation within ~75 seconds.
|
||||
"""
|
||||
|
||||
success: bool
|
||||
charge_id: str | None # reservation_id
|
||||
_quota_type: QuotaType
|
||||
_tenant_id: str | None = None
|
||||
_feature_key: str | None = None
|
||||
_amount: int = 0
|
||||
_committed: bool = field(default=False, repr=False)
|
||||
|
||||
def commit(self, actual_amount: int | None = None) -> None:
|
||||
"""
|
||||
Confirm the consumption with actual amount.
|
||||
|
||||
Args:
|
||||
actual_amount: Actual amount consumed. Defaults to the reserved amount.
|
||||
If less than reserved, the difference is refunded automatically.
|
||||
"""
|
||||
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
|
||||
return
|
||||
|
||||
try:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
amount = actual_amount if actual_amount is not None else self._amount
|
||||
BillingService.quota_commit(
|
||||
tenant_id=self._tenant_id,
|
||||
feature_key=self._feature_key,
|
||||
reservation_id=self.charge_id,
|
||||
actual_amount=amount,
|
||||
)
|
||||
self._committed = True
|
||||
logger.debug(
|
||||
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
|
||||
self._quota_type,
|
||||
self._tenant_id,
|
||||
self.charge_id,
|
||||
amount,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
|
||||
|
||||
def refund(self) -> None:
|
||||
"""
|
||||
Release the reserved quota (cancel the charge).
|
||||
|
||||
Safe to call even if:
|
||||
- charge failed or was disabled (charge_id is None)
|
||||
- already committed (Release after Commit is a no-op)
|
||||
- already refunded (idempotent)
|
||||
|
||||
This method guarantees no exceptions will be raised.
|
||||
"""
|
||||
if not self.charge_id or not self._tenant_id or not self._feature_key:
|
||||
return
|
||||
|
||||
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
|
||||
|
||||
|
||||
def unlimited() -> QuotaCharge:
|
||||
from enums.quota_type import QuotaType
|
||||
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
|
||||
|
||||
|
||||
class QuotaService:
|
||||
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
|
||||
|
||||
@staticmethod
|
||||
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Reserve + immediate Commit (one-shot mode).
|
||||
|
||||
The returned QuotaCharge supports .refund() which calls Release.
|
||||
For two-phase usage (e.g. streaming), use reserve() directly.
|
||||
"""
|
||||
charge = QuotaService.reserve(quota_type, tenant_id, amount)
|
||||
if charge.success and charge.charge_id:
|
||||
charge.commit()
|
||||
return charge
|
||||
|
||||
@staticmethod
|
||||
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Reserve quota before task execution (Reserve phase only).
|
||||
|
||||
The caller MUST call charge.commit() after the task succeeds,
|
||||
or charge.refund() if the task fails.
|
||||
|
||||
Raises:
|
||||
QuotaExceededError: When quota is insufficient
|
||||
"""
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.debug("Billing disabled, allowing request for %s", tenant_id)
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
|
||||
|
||||
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to reserve must be greater than 0")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
feature_key = quota_type.billing_key
|
||||
|
||||
try:
|
||||
reserve_resp = BillingService.quota_reserve(
|
||||
tenant_id=tenant_id,
|
||||
feature_key=feature_key,
|
||||
request_id=request_id,
|
||||
amount=amount,
|
||||
)
|
||||
|
||||
reservation_id = reserve_resp.get("reservation_id")
|
||||
if not reservation_id:
|
||||
logger.warning(
|
||||
"Reserve returned no reservation_id for %s, feature %s, response: %s",
|
||||
tenant_id,
|
||||
quota_type.value,
|
||||
reserve_resp,
|
||||
)
|
||||
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
|
||||
|
||||
logger.debug(
|
||||
"Reserved %d %s quota for tenant %s, reservation_id: %s",
|
||||
amount,
|
||||
quota_type.value,
|
||||
tenant_id,
|
||||
reservation_id,
|
||||
)
|
||||
return QuotaCharge(
|
||||
success=True,
|
||||
charge_id=reservation_id,
|
||||
_quota_type=quota_type,
|
||||
_tenant_id=tenant_id,
|
||||
_feature_key=feature_key,
|
||||
_amount=amount,
|
||||
)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return unlimited()
|
||||
|
||||
@staticmethod
|
||||
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return True
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to check must be greater than 0")
|
||||
|
||||
try:
|
||||
remaining = QuotaService.get_remaining(quota_type, tenant_id)
|
||||
return remaining >= amount if remaining != -1 else True
|
||||
except Exception:
|
||||
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
|
||||
"""Release a reservation. Guarantees no exceptions."""
|
||||
try:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return
|
||||
|
||||
if not reservation_id:
|
||||
return
|
||||
|
||||
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
|
||||
BillingService.quota_release(
|
||||
tenant_id=tenant_id,
|
||||
feature_key=feature_key,
|
||||
reservation_id=reservation_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
|
||||
|
||||
@staticmethod
|
||||
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
|
||||
from services.billing_service import BillingService
|
||||
|
||||
try:
|
||||
usage_info = BillingService.get_quota_info(tenant_id)
|
||||
if isinstance(usage_info, dict):
|
||||
feature_info = usage_info.get(quota_type.billing_key, {})
|
||||
if isinstance(feature_info, dict):
|
||||
limit = feature_info.get("limit", 0)
|
||||
usage = feature_info.get("usage", 0)
|
||||
if limit == -1:
|
||||
return -1
|
||||
return max(0, limit - usage)
|
||||
return 0
|
||||
except Exception:
|
||||
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
|
||||
return -1
|
||||
@ -37,7 +37,6 @@ from models.workflow import Workflow
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.workflow.entities import WebhookTriggerData
|
||||
|
||||
@ -759,47 +758,45 @@ class WebhookService:
|
||||
Exception: If workflow execution fails
|
||||
"""
|
||||
try:
|
||||
workflow_inputs = cls.build_workflow_inputs(webhook_data)
|
||||
with Session(db.engine) as session:
|
||||
# Prepare inputs for the webhook node
|
||||
# The webhook node expects webhook_data in the inputs
|
||||
workflow_inputs = cls.build_workflow_inputs(webhook_data)
|
||||
|
||||
trigger_data = WebhookTriggerData(
|
||||
app_id=webhook_trigger.app_id,
|
||||
workflow_id=workflow.id,
|
||||
root_node_id=webhook_trigger.node_id,
|
||||
inputs=workflow_inputs,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
)
|
||||
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
app_id=webhook_trigger.app_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping webhook trigger %s",
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.webhook_id,
|
||||
# Create trigger data
|
||||
trigger_data = WebhookTriggerData(
|
||||
app_id=webhook_trigger.app_id,
|
||||
workflow_id=workflow.id,
|
||||
root_node_id=webhook_trigger.node_id, # Start from the webhook node
|
||||
inputs=workflow_inputs,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
# NOTE: don not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`
|
||||
# trigger_workflow_async need to handle multipe session commits internally
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session,
|
||||
end_user,
|
||||
trigger_data,
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
app_id=webhook_trigger.app_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# consume quota before triggering workflow execution
|
||||
try:
|
||||
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping webhook trigger %s",
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.webhook_id,
|
||||
)
|
||||
quota_charge.commit()
|
||||
except Exception:
|
||||
quota_charge.refund()
|
||||
raise
|
||||
raise
|
||||
|
||||
# Trigger workflow execution asynchronously
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session,
|
||||
end_user,
|
||||
trigger_data,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
|
||||
|
||||
@ -132,38 +132,31 @@ class WorkflowService:
|
||||
if workflow_id:
|
||||
return self.get_published_workflow_by_id(app_model, workflow_id)
|
||||
# fetch draft workflow by app_model
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(
|
||||
self, app_model: App, workflow_id: str, session: Session | None = None
|
||||
) -> Workflow | None:
|
||||
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
fetch published workflow by workflow_id
|
||||
|
||||
When ``session`` is provided, reuse it so callers that already hold a
|
||||
Session avoid checking out an extra request-scoped ``db.session``
|
||||
connection. Falls back to ``db.session`` for backward compatibility.
|
||||
"""
|
||||
bind = session if session is not None else db.session
|
||||
workflow = bind.scalar(
|
||||
select(Workflow)
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == workflow_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not workflow:
|
||||
return None
|
||||
@ -174,27 +167,23 @@ class WorkflowService:
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None:
|
||||
def get_published_workflow(self, app_model: App) -> Workflow | None:
|
||||
"""
|
||||
Get published workflow
|
||||
|
||||
When ``session`` is provided, reuse it so callers that already hold a
|
||||
Session avoid checking out an extra request-scoped ``db.session``
|
||||
connection. Falls back to ``db.session`` for backward compatibility.
|
||||
"""
|
||||
|
||||
if not app_model.workflow_id:
|
||||
return None
|
||||
|
||||
bind = session if session is not None else db.session
|
||||
workflow = bind.scalar(
|
||||
select(Workflow)
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
@ -156,12 +156,7 @@ def _execute_workflow_common(
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
# NOTE (hj24)
|
||||
# Release the transaction before the blocking generate() call,
|
||||
# otherwise the connection stays "idle in transaction" for hours.
|
||||
session.commit()
|
||||
# NOTE END
|
||||
|
||||
# Execute the workflow with the trigger type
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
|
||||
@ -28,7 +28,7 @@ from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from enums.quota_type import QuotaType
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from models.enums import (
|
||||
AppTriggerType,
|
||||
CreatorUserRole,
|
||||
@ -42,7 +42,6 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
|
||||
@ -259,58 +258,59 @@ def dispatch_triggered_workflow(
|
||||
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
|
||||
)
|
||||
trigger_entity: TriggerProviderEntity = provider_controller.entity
|
||||
|
||||
# Ensure expire_on_commit is set to False to remain workflows available
|
||||
with session_factory.create_session() as session:
|
||||
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
|
||||
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=subscription.tenant_id,
|
||||
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
for plugin_trigger in subscribers:
|
||||
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
|
||||
if not workflow:
|
||||
logger.error(
|
||||
"Workflow not found for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
continue
|
||||
|
||||
event_node = None
|
||||
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
|
||||
if node_id == plugin_trigger.node_id:
|
||||
event_node = node_config
|
||||
break
|
||||
|
||||
if not event_node:
|
||||
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
|
||||
continue
|
||||
|
||||
trigger_metadata = PluginTriggerMetadata(
|
||||
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
|
||||
endpoint_id=subscription.endpoint_id,
|
||||
provider_id=subscription.provider_id,
|
||||
event_name=event_name,
|
||||
icon_filename=trigger_entity.identity.icon or "",
|
||||
icon_dark_filename=trigger_entity.identity.icon_dark or "",
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.TRIGGER,
|
||||
tenant_id=subscription.tenant_id,
|
||||
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
|
||||
user_id=user_id,
|
||||
)
|
||||
for plugin_trigger in subscribers:
|
||||
# Get workflow from mapping
|
||||
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
|
||||
if not workflow:
|
||||
logger.error(
|
||||
"Workflow not found for app %s",
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
continue
|
||||
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
|
||||
logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id)
|
||||
return dispatched_count
|
||||
# Find the trigger node in the workflow
|
||||
event_node = None
|
||||
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
|
||||
if node_id == plugin_trigger.node_id:
|
||||
event_node = node_config
|
||||
break
|
||||
|
||||
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
|
||||
invoke_response: TriggerInvokeEventResponse | None = None
|
||||
if not event_node:
|
||||
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
|
||||
continue
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
# invoke trigger
|
||||
trigger_metadata = PluginTriggerMetadata(
|
||||
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
|
||||
endpoint_id=subscription.endpoint_id,
|
||||
provider_id=subscription.provider_id,
|
||||
event_name=event_name,
|
||||
icon_filename=trigger_entity.identity.icon or "",
|
||||
icon_dark_filename=trigger_entity.identity.icon_dark or "",
|
||||
)
|
||||
|
||||
# consume quota before invoking trigger
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
|
||||
logger.info(
|
||||
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
|
||||
)
|
||||
return 0
|
||||
|
||||
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
|
||||
invoke_response: TriggerInvokeEventResponse | None = None
|
||||
try:
|
||||
invoke_response = TriggerManager.invoke_trigger_event(
|
||||
tenant_id=subscription.tenant_id,
|
||||
@ -387,7 +387,6 @@ def dispatch_triggered_workflow(
|
||||
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
|
||||
|
||||
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
|
||||
quota_charge.commit()
|
||||
dispatched_count += 1
|
||||
logger.info(
|
||||
"Triggered workflow for app %s with trigger event %s",
|
||||
@ -402,7 +401,7 @@ def dispatch_triggered_workflow(
|
||||
plugin_trigger.app_id,
|
||||
)
|
||||
|
||||
return dispatched_count
|
||||
return dispatched_count
|
||||
|
||||
|
||||
def dispatch_triggered_workflows(
|
||||
|
||||
@ -8,11 +8,10 @@ from core.workflow.nodes.trigger_schedule.exc import (
|
||||
ScheduleNotFoundError,
|
||||
TenantOwnerNotFoundError,
|
||||
)
|
||||
from enums.quota_type import QuotaType
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService, unlimited
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
from services.workflow.entities import ScheduleTriggerData
|
||||
@ -33,7 +32,6 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
TenantOwnerNotFoundError: If no owner/admin for tenant
|
||||
ScheduleExecutionError: If workflow trigger fails
|
||||
"""
|
||||
# Ensure expire_on_commit is set to False to remain schedule/tenant_owner available
|
||||
with session_factory.create_session() as session:
|
||||
schedule = session.get(WorkflowSchedulePlan, schedule_id)
|
||||
if not schedule:
|
||||
@ -43,16 +41,16 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
if not tenant_owner:
|
||||
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")
|
||||
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
|
||||
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
|
||||
return
|
||||
quota_charge = unlimited()
|
||||
try:
|
||||
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
|
||||
except QuotaExceededError:
|
||||
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
|
||||
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
|
||||
return
|
||||
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
# Production dispatch: Trigger the workflow normally
|
||||
response = AsyncWorkflowService.trigger_workflow_async(
|
||||
session=session,
|
||||
user=tenant_owner,
|
||||
@ -63,10 +61,9 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
||||
tenant_id=schedule.tenant_id,
|
||||
),
|
||||
)
|
||||
quota_charge.commit()
|
||||
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
|
||||
except Exception as e:
|
||||
quota_charge.refund()
|
||||
raise ScheduleExecutionError(
|
||||
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
|
||||
) from e
|
||||
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
|
||||
except Exception as e:
|
||||
quota_charge.refund()
|
||||
raise ScheduleExecutionError(
|
||||
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
|
||||
) from e
|
||||
|
||||
@ -163,9 +163,11 @@ class DifyTestContainers:
|
||||
wait_for_logs(self.redis, "Ready to accept connections", timeout=30)
|
||||
logger.info("Redis container is ready and accepting connections")
|
||||
|
||||
# Start Dify Sandbox container for code execution environment.
|
||||
# Start Dify Sandbox container for code execution environment
|
||||
# Dify Sandbox provides a secure environment for executing user code
|
||||
# Use pinned version 0.2.12 to match production docker-compose configuration
|
||||
logger.info("Initializing Dify Sandbox container...")
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.14").with_network(self.network)
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network)
|
||||
self.dify_sandbox.with_exposed_ports(8194)
|
||||
self.dify_sandbox.env = {
|
||||
"API_KEY": "test_api_key",
|
||||
@ -185,7 +187,7 @@ class DifyTestContainers:
|
||||
# Start Dify Plugin Daemon container for plugin management
|
||||
# Dify Plugin Daemon provides plugin lifecycle management and execution
|
||||
logger.info("Initializing Dify Plugin Daemon container...")
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.3-local").with_network(
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.4-local").with_network(
|
||||
self.network
|
||||
)
|
||||
self.dify_plugin_daemon.with_exposed_ports(5002)
|
||||
|
||||
@ -36,19 +36,12 @@ class TestAppGenerateService:
|
||||
) as mock_message_based_generator,
|
||||
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
|
||||
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
|
||||
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
|
||||
):
|
||||
# Setup default mock returns for billing service
|
||||
mock_billing_service.quota_reserve.return_value = {
|
||||
"reservation_id": "test-reservation-id",
|
||||
"available": 100,
|
||||
"reserved": 1,
|
||||
}
|
||||
mock_billing_service.quota_commit.return_value = {
|
||||
"available": 99,
|
||||
"reserved": 0,
|
||||
"refunded": 0,
|
||||
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
|
||||
"result": "success",
|
||||
"history_id": "test_history_id",
|
||||
}
|
||||
|
||||
# Setup default mock returns for workflow service
|
||||
@ -108,8 +101,6 @@ class TestAppGenerateService:
|
||||
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
|
||||
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
|
||||
mock_quota_dify_config.BILLING_ENABLED = False
|
||||
|
||||
mock_global_dify_config.BILLING_ENABLED = False
|
||||
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
@ -127,7 +118,6 @@ class TestAppGenerateService:
|
||||
"message_based_generator": mock_message_based_generator,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"dify_config": mock_dify_config,
|
||||
"quota_dify_config": mock_quota_dify_config,
|
||||
"global_dify_config": mock_global_dify_config,
|
||||
}
|
||||
|
||||
@ -475,7 +465,6 @@ class TestAppGenerateService:
|
||||
|
||||
# Set BILLING_ENABLED to True for this test
|
||||
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
|
||||
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
|
||||
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
|
||||
|
||||
# Setup test arguments
|
||||
@ -489,10 +478,8 @@ class TestAppGenerateService:
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify billing two-phase quota (reserve + commit)
|
||||
billing = mock_external_service_dependencies["billing_service"]
|
||||
billing.quota_reserve.assert_called_once()
|
||||
billing.quota_commit.assert_called_once()
|
||||
# Verify billing service was called to consume quota
|
||||
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
|
||||
|
||||
def test_generate_with_invalid_app_mode(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
|
||||
@ -1,517 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
|
||||
from enums.quota_type import QuotaType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import AppTriggerStatus, AppTriggerType
|
||||
from models.model import App
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class WebhookServiceRelationshipFactory:
|
||||
@staticmethod
|
||||
def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]:
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"webhook-{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App:
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"Webhook App {uuid4()}",
|
||||
description="",
|
||||
mode="workflow",
|
||||
icon_type="emoji",
|
||||
icon="bot",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_workflow(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
app: App,
|
||||
account: Account,
|
||||
node_ids: list[str],
|
||||
version: str,
|
||||
) -> Workflow:
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"type": TRIGGER_WEBHOOK_NODE_TYPE,
|
||||
"title": f"Webhook {node_id}",
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"headers": [],
|
||||
"params": [],
|
||||
"body": [],
|
||||
"status_code": 200,
|
||||
"response_body": '{"status": "ok"}',
|
||||
"timeout": 30,
|
||||
},
|
||||
}
|
||||
for node_id in node_ids
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
workflow = Workflow(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
version=version,
|
||||
)
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
return workflow
|
||||
|
||||
@staticmethod
|
||||
def create_webhook_trigger(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
app: App,
|
||||
account: Account,
|
||||
node_id: str,
|
||||
webhook_id: str | None = None,
|
||||
) -> WorkflowWebhookTrigger:
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
app_id=app.id,
|
||||
node_id=node_id,
|
||||
tenant_id=app.tenant_id,
|
||||
webhook_id=webhook_id or uuid4().hex[:24],
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(webhook_trigger)
|
||||
db_session_with_containers.commit()
|
||||
return webhook_trigger
|
||||
|
||||
@staticmethod
|
||||
def create_app_trigger(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
app: App,
|
||||
node_id: str,
|
||||
status: AppTriggerStatus,
|
||||
) -> AppTrigger:
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
node_id=node_id,
|
||||
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
|
||||
provider_name="webhook",
|
||||
title=f"Webhook {node_id}",
|
||||
status=status,
|
||||
)
|
||||
db_session_with_containers.add(app_trigger)
|
||||
db_session_with_containers.commit()
|
||||
return app_trigger
|
||||
|
||||
|
||||
class TestWebhookServiceLookupWithContainers:
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="App trigger not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="rate limited"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="disabled"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["published-node"],
|
||||
version="2026-04-14.001",
|
||||
)
|
||||
draft_workflow = factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["debug-node"],
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="debug-node"
|
||||
)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
webhook_trigger.webhook_id,
|
||||
is_debug=True,
|
||||
)
|
||||
|
||||
assert got_trigger.id == webhook_trigger.id
|
||||
assert got_workflow.id == draft_workflow.id
|
||||
assert got_node_config["id"] == "debug-node"
|
||||
|
||||
|
||||
class TestWebhookServiceTriggerExecutionWithContainers:
|
||||
def test_trigger_workflow_execution_triggers_async_workflow_successfully(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
end_user = SimpleNamespace(id=str(uuid4()))
|
||||
webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
|
||||
|
||||
quota_charge = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
|
||||
return_value=end_user,
|
||||
),
|
||||
patch(
|
||||
"services.trigger.webhook_service.QuotaService.reserve",
|
||||
return_value=quota_charge,
|
||||
) as mock_reserve,
|
||||
patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
|
||||
):
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
mock_reserve.assert_called_once()
|
||||
reserve_args = mock_reserve.call_args.args
|
||||
assert reserve_args[0] == QuotaType.TRIGGER
|
||||
assert reserve_args[1] == webhook_trigger.tenant_id
|
||||
quota_charge.commit.assert_called_once()
|
||||
mock_trigger.assert_called_once()
|
||||
trigger_args = mock_trigger.call_args.args
|
||||
assert trigger_args[1] is end_user
|
||||
assert trigger_args[2].workflow_id == workflow.id
|
||||
assert trigger_args[2].root_node_id == webhook_trigger.node_id
|
||||
|
||||
def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
|
||||
return_value=SimpleNamespace(id=str(uuid4())),
|
||||
),
|
||||
patch(
|
||||
"services.trigger.webhook_service.QuotaService.reserve",
|
||||
side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
|
||||
),
|
||||
patch(
|
||||
"services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited"
|
||||
) as mock_mark_rate_limited,
|
||||
):
|
||||
with pytest.raises(QuotaExceededError):
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
|
||||
workflow,
|
||||
)
|
||||
|
||||
mock_mark_rate_limited.assert_called_once_with(tenant.id)
|
||||
|
||||
def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
|
||||
workflow,
|
||||
)
|
||||
|
||||
mock_logger_exception.assert_called_once()
|
||||
|
||||
|
||||
class TestWebhookServiceRelationshipSyncWithContainers:
|
||||
def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)]
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="maximum webhook node limit"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_raises_when_lock_not_acquired(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
|
||||
with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock):
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
stale_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_id="node-stale",
|
||||
webhook_id="stale-webhook-id-000001",
|
||||
)
|
||||
stale_trigger_id = stale_trigger.id
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["node-new"],
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001"
|
||||
):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
records = db_session_with_containers.scalars(
|
||||
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id)
|
||||
).all()
|
||||
|
||||
assert [record.node_id for record in records] == ["node-new"]
|
||||
assert records[0].webhook_id == "new-webhook-id-000001"
|
||||
assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None
|
||||
|
||||
def test_sync_webhook_relationships_sets_redis_cache_for_new_record(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
account=account,
|
||||
node_ids=["node-cache"],
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
)
|
||||
cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache"
|
||||
|
||||
with patch(
|
||||
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001"
|
||||
):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key)
|
||||
assert cached_payload is not None
|
||||
assert cached_payload["node_id"] == "node-cache"
|
||||
assert cached_payload["webhook_id"] == "cache-webhook-id-00001"
|
||||
|
||||
def test_sync_webhook_relationships_logs_when_lock_release_fails(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = RuntimeError("release failed")
|
||||
|
||||
with (
|
||||
patch("services.trigger.webhook_service.redis_client.lock", return_value=lock),
|
||||
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
|
||||
):
|
||||
WebhookService.sync_webhook_relationships(app, workflow)
|
||||
|
||||
mock_logger_exception.assert_called_once()
|
||||
|
||||
|
||||
def _read_cache(cache_key: str) -> dict[str, str] | None:
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
cached = redis_client.get(cache_key)
|
||||
if not cached:
|
||||
return None
|
||||
if isinstance(cached, bytes):
|
||||
cached = cached.decode("utf-8")
|
||||
return json.loads(cached)
|
||||
|
||||
|
||||
WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache)
|
||||
@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log(
|
||||
)
|
||||
|
||||
# Mock quota to avoid rate limiting
|
||||
from services import quota_service
|
||||
from enums import quota_type
|
||||
|
||||
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
|
||||
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
|
||||
|
||||
# Execute schedule trigger
|
||||
workflow_schedule_tasks.run_schedule_trigger(plan.id)
|
||||
|
||||
@ -1,245 +0,0 @@
|
||||
"""Unit tests for inner_api app DSL import/export endpoints.
|
||||
|
||||
Tests Pydantic model validation, endpoint handler logic, and the
|
||||
_get_active_account helper. Auth/setup decorators are tested separately
|
||||
in test_auth_wraps.py; handler tests use inspect.unwrap() to bypass them.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.app.dsl import (
|
||||
EnterpriseAppDSLExport,
|
||||
EnterpriseAppDSLImport,
|
||||
InnerAppDSLImportPayload,
|
||||
_get_active_account,
|
||||
)
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
class TestInnerAppDSLImportPayload:
|
||||
"""Test InnerAppDSLImportPayload Pydantic model validation."""
|
||||
|
||||
def test_valid_payload_all_fields(self):
|
||||
data = {
|
||||
"yaml_content": "version: 0.6.0\nkind: app\n",
|
||||
"creator_email": "user@example.com",
|
||||
"name": "My App",
|
||||
"description": "A test app",
|
||||
}
|
||||
payload = InnerAppDSLImportPayload.model_validate(data)
|
||||
assert payload.yaml_content == data["yaml_content"]
|
||||
assert payload.creator_email == "user@example.com"
|
||||
assert payload.name == "My App"
|
||||
assert payload.description == "A test app"
|
||||
|
||||
def test_valid_payload_optional_fields_omitted(self):
|
||||
data = {
|
||||
"yaml_content": "version: 0.6.0\n",
|
||||
"creator_email": "user@example.com",
|
||||
}
|
||||
payload = InnerAppDSLImportPayload.model_validate(data)
|
||||
assert payload.name is None
|
||||
assert payload.description is None
|
||||
|
||||
def test_missing_yaml_content_fails(self):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
InnerAppDSLImportPayload.model_validate({"creator_email": "a@b.com"})
|
||||
assert "yaml_content" in str(exc_info.value)
|
||||
|
||||
def test_missing_creator_email_fails(self):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
InnerAppDSLImportPayload.model_validate({"yaml_content": "test"})
|
||||
assert "creator_email" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGetActiveAccount:
|
||||
"""Test the _get_active_account helper function."""
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_active_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "active"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
result = _get_active_account("user@example.com")
|
||||
|
||||
assert result is mock_account
|
||||
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_inactive_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "banned"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
result = _get_active_account("banned@example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_nonexistent_email(self, mock_db):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = _get_active_account("missing@example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestEnterpriseAppDSLImport:
|
||||
"""Test EnterpriseAppDSLImport endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseAppDSLImport()
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_import_deps(self):
|
||||
"""Patch db, Session, and AppDslService for import handler tests."""
|
||||
with (
|
||||
patch("controllers.inner_api.app.dsl.db"),
|
||||
patch("controllers.inner_api.app.dsl.Session") as mock_session,
|
||||
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
|
||||
):
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
self._mock_dsl = MagicMock()
|
||||
mock_dsl_cls.return_value = self._mock_dsl
|
||||
yield
|
||||
|
||||
def _make_import_result(self, status: ImportStatus, **kwargs) -> "Import":
|
||||
from services.app_dsl_service import Import
|
||||
|
||||
result = Import(
|
||||
id="import-id",
|
||||
status=status,
|
||||
app_id=kwargs.get("app_id", "app-123"),
|
||||
app_mode=kwargs.get("app_mode", "workflow"),
|
||||
)
|
||||
return result
|
||||
|
||||
@pytest.mark.usefixtures("_mock_import_deps")
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
def test_import_success_returns_200(self, mock_get_account, api_instance, app: Flask):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.COMPLETED)
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"yaml_content": "version: 0.6.0\n",
|
||||
"creator_email": "user@example.com",
|
||||
}
|
||||
result = unwrapped(api_instance, workspace_id="ws-123")
|
||||
|
||||
body, status_code = result
|
||||
assert status_code == 200
|
||||
assert body["status"] == "completed"
|
||||
mock_account.set_tenant_id.assert_called_once_with("ws-123")
|
||||
|
||||
@pytest.mark.usefixtures("_mock_import_deps")
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
def test_import_pending_returns_202(self, mock_get_account, api_instance, app: Flask):
|
||||
mock_get_account.return_value = MagicMock()
|
||||
self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.PENDING)
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"}
|
||||
body, status_code = unwrapped(api_instance, workspace_id="ws-123")
|
||||
|
||||
assert status_code == 202
|
||||
assert body["status"] == "pending"
|
||||
|
||||
@pytest.mark.usefixtures("_mock_import_deps")
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
def test_import_failed_returns_400(self, mock_get_account, api_instance, app: Flask):
|
||||
mock_get_account.return_value = MagicMock()
|
||||
self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.FAILED)
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"}
|
||||
body, status_code = unwrapped(api_instance, workspace_id="ws-123")
|
||||
|
||||
assert status_code == 400
|
||||
assert body["status"] == "failed"
|
||||
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask):
|
||||
mock_get_account.return_value = None
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"yaml_content": "test", "creator_email": "missing@e.com"}
|
||||
result = unwrapped(api_instance, workspace_id="ws-123")
|
||||
|
||||
body, status_code = result
|
||||
assert status_code == 404
|
||||
assert "missing@e.com" in body["message"]
|
||||
|
||||
|
||||
class TestEnterpriseAppDSLExport:
|
||||
"""Test EnterpriseAppDSLExport endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseAppDSLExport()
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.AppDslService")
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
with app.test_request_context("?include_secret=false"):
|
||||
result = unwrapped(api_instance, app_id="app-123")
|
||||
|
||||
body, status_code = result
|
||||
assert status_code == 200
|
||||
assert body["data"] == "version: 0.6.0\nkind: app\n"
|
||||
mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=False)
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.AppDslService")
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "yaml-data"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
with app.test_request_context("?include_secret=true"):
|
||||
result = unwrapped(api_instance, app_id="app-123")
|
||||
|
||||
body, status_code = result
|
||||
assert status_code == 200
|
||||
mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=True)
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
with app.test_request_context("?include_secret=false"):
|
||||
result = unwrapped(api_instance, app_id="nonexistent")
|
||||
|
||||
body, status_code = result
|
||||
assert status_code == 404
|
||||
assert "app not found" in body["message"]
|
||||
@ -8,7 +8,6 @@ import core.app.apps.pipeline.pipeline_generator as module
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from models.enums import DataSourceType
|
||||
|
||||
|
||||
class FakeRagPipelineGenerateEntity(SimpleNamespace):
|
||||
@ -559,24 +558,6 @@ def test_build_document_sets_metadata_for_builtin_fields(generator, mocker):
|
||||
assert document.doc_metadata
|
||||
|
||||
|
||||
def test_build_document_supports_online_drive_datasource_type(generator):
|
||||
document = generator._build_document(
|
||||
tenant_id="tenant",
|
||||
dataset_id="ds",
|
||||
built_in_field_enabled=True,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DRIVE,
|
||||
datasource_info={"id": "file-1", "bucket": "bucket-1", "name": "drive.pdf", "type": "file"},
|
||||
created_from="rag-pipeline",
|
||||
position=1,
|
||||
account=_build_user(),
|
||||
batch="batch",
|
||||
document_form="text",
|
||||
)
|
||||
|
||||
assert DataSourceType(document.data_source_type) == DataSourceType.ONLINE_DRIVE
|
||||
assert document.name == "drive.pdf"
|
||||
|
||||
|
||||
def test_build_document_invalid_datasource_type(generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator._build_document(
|
||||
|
||||
@ -115,12 +115,14 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
def test_delete_by_ids_with_exactly_1000(self, vector_instance):
|
||||
"""Test deletion with exactly 1000 IDs triggers a single batch."""
|
||||
def test_delete_by_ids_with_large_batch(self, vector_instance):
|
||||
"""Test deletion with a large batch of IDs."""
|
||||
# Create 1000 IDs
|
||||
ids = [f"doc_{i}" for i in range(1000)]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
# Verify single delete call with all IDs
|
||||
vector_instance._client.delete.assert_called_once()
|
||||
call_args = vector_instance._client.delete.call_args
|
||||
|
||||
@ -128,28 +130,11 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
filter_obj = filter_selector.filter
|
||||
field_condition = filter_obj.must[0]
|
||||
|
||||
# Verify all 1000 IDs are in the batch
|
||||
assert len(field_condition.match.any) == 1000
|
||||
assert "doc_0" in field_condition.match.any
|
||||
assert "doc_999" in field_condition.match.any
|
||||
|
||||
def test_delete_by_ids_splits_into_batches(self, vector_instance):
|
||||
"""Test deletion with >1000 IDs triggers multiple batched calls."""
|
||||
ids = [f"doc_{i}" for i in range(2500)]
|
||||
|
||||
vector_instance.delete_by_ids(ids)
|
||||
|
||||
assert vector_instance._client.delete.call_count == 3
|
||||
|
||||
batches = []
|
||||
for call in vector_instance._client.delete.call_args_list:
|
||||
filter_selector = call[1]["points_selector"]
|
||||
field_condition = filter_selector.filter.must[0]
|
||||
batches.append(field_condition.match.any)
|
||||
|
||||
assert len(batches[0]) == 1000
|
||||
assert len(batches[1]) == 1000
|
||||
assert len(batches[2]) == 500
|
||||
|
||||
def test_delete_by_ids_filter_structure(self, vector_instance):
|
||||
"""Test that the filter structure is correctly constructed."""
|
||||
ids = ["doc1", "doc2"]
|
||||
@ -173,57 +158,3 @@ class TestTidbOnQdrantVectorDeleteByIds:
|
||||
# Verify MatchAny structure
|
||||
assert isinstance(field_condition.match, rest.MatchAny)
|
||||
assert field_condition.match.any == ids
|
||||
|
||||
|
||||
class TestInitVectorEndpointSelection:
|
||||
"""Test that init_vector selects the correct qdrant endpoint.
|
||||
|
||||
We avoid importing the full module (which triggers Flask app context)
|
||||
by testing the endpoint selection logic directly on TidbOnQdrantConfig.
|
||||
"""
|
||||
|
||||
def test_uses_binding_endpoint_when_present(self):
|
||||
binding_endpoint = "https://qdrant-custom.tidb.com"
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-custom.tidb.com"
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == "https://qdrant-custom.tidb.com"
|
||||
|
||||
def test_falls_back_to_global_when_binding_endpoint_is_none(self):
|
||||
binding_endpoint = None
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-global.tidb.com"
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == "https://qdrant-global.tidb.com"
|
||||
|
||||
def test_falls_back_to_empty_when_both_none(self):
|
||||
binding_endpoint = None
|
||||
global_url = None
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == ""
|
||||
config = TidbOnQdrantConfig(endpoint=qdrant_url)
|
||||
assert config.endpoint == ""
|
||||
|
||||
def test_binding_endpoint_takes_precedence_over_global(self):
|
||||
binding_endpoint = "https://qdrant-ap-southeast.tidb.com"
|
||||
global_url = "https://qdrant-us-east.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-ap-southeast.tidb.com"
|
||||
|
||||
def test_empty_string_binding_endpoint_falls_back_to_global(self):
|
||||
binding_endpoint = ""
|
||||
global_url = "https://qdrant-global.tidb.com"
|
||||
|
||||
qdrant_url = binding_endpoint or global_url or ""
|
||||
|
||||
assert qdrant_url == "https://qdrant-global.tidb.com"
|
||||
|
||||
@ -1,308 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
class TestExtractQdrantEndpoint:
|
||||
"""Unit tests for TidbService.extract_qdrant_endpoint."""
|
||||
|
||||
def test_returns_endpoint_when_host_present(self):
|
||||
response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}}
|
||||
result = TidbService.extract_qdrant_endpoint(response)
|
||||
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
|
||||
|
||||
def test_returns_none_when_host_missing(self):
|
||||
response = {"endpoints": {"public": {}}}
|
||||
assert TidbService.extract_qdrant_endpoint(response) is None
|
||||
|
||||
def test_returns_none_when_public_missing(self):
|
||||
response = {"endpoints": {}}
|
||||
assert TidbService.extract_qdrant_endpoint(response) is None
|
||||
|
||||
def test_returns_none_when_endpoints_missing(self):
|
||||
assert TidbService.extract_qdrant_endpoint({}) is None
|
||||
|
||||
|
||||
class TestFetchQdrantEndpoint:
|
||||
"""Unit tests for TidbService.fetch_qdrant_endpoint."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_endpoint_when_host_present(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {
|
||||
"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}
|
||||
}
|
||||
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
|
||||
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = None
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_host_missing(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {"endpoints": {"public": {}}}
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_when_endpoints_missing(self, mock_get_cluster):
|
||||
mock_get_cluster.return_value = {}
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
def test_returns_none_on_exception(self, mock_get_cluster):
|
||||
mock_get_cluster.side_effect = RuntimeError("network error")
|
||||
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
|
||||
|
||||
|
||||
class TestCreateTidbServerlessClusterQdrantEndpoint:
|
||||
"""Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {
|
||||
"state": "ACTIVE",
|
||||
"userPrefix": "pfx",
|
||||
"endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}},
|
||||
}
|
||||
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
|
||||
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] is None
|
||||
|
||||
|
||||
class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
|
||||
"""Verify that batch_create includes qdrant_endpoint per cluster."""
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
cluster_name = "abc123"
|
||||
mock_http.post.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]},
|
||||
)
|
||||
mock_redis.setex = MagicMock()
|
||||
mock_redis.get.return_value = b"password123"
|
||||
|
||||
result = TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
|
||||
|
||||
|
||||
class TestCreateTidbServerlessClusterRetry:
|
||||
"""Cover retry/logging paths in create_tidb_serverless_cluster."""
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.side_effect = [
|
||||
{"state": "CREATING", "userPrefix": ""},
|
||||
{"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}},
|
||||
]
|
||||
|
||||
with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.time.sleep"):
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is not None
|
||||
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com"
|
||||
assert mock_get_cluster.call_count == 2
|
||||
|
||||
@patch.object(TidbService, "get_tidb_serverless_cluster")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
|
||||
mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""}
|
||||
|
||||
with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.time.sleep"):
|
||||
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_raises_on_post_failure(self, mock_config, mock_http):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_response = MagicMock(status_code=400, text="Bad Request")
|
||||
mock_response.raise_for_status.side_effect = Exception("HTTP 400")
|
||||
mock_http.post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 400"):
|
||||
TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
|
||||
|
||||
|
||||
class TestBatchCreateEdgeCases:
|
||||
"""Cover logging/edge-case branches in batch_create."""
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_http.post.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]},
|
||||
)
|
||||
mock_redis.setex = MagicMock()
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
assert len(result) == 0
|
||||
mock_fetch_ep.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.redis_client")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
|
||||
def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis):
|
||||
mock_config.TIDB_SPEND_LIMIT = 10
|
||||
mock_response = MagicMock(status_code=500, text="Server Error")
|
||||
mock_response.raise_for_status.side_effect = Exception("HTTP 500")
|
||||
mock_http.post.return_value = mock_response
|
||||
mock_redis.setex = MagicMock()
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
TidbService.batch_create_tidb_serverless_cluster(
|
||||
batch_size=1,
|
||||
project_id="proj",
|
||||
api_url="url",
|
||||
iam_url="iam",
|
||||
public_key="pub",
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
|
||||
class TestBatchUpdateTidbServerlessClusterStatus:
|
||||
"""Verify that status updates only expose clusters after qdrant endpoint is ready."""
|
||||
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.db")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db):
|
||||
binding = SimpleNamespace(
|
||||
cluster_id="c-1",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
account="root",
|
||||
qdrant_endpoint=None,
|
||||
)
|
||||
mock_http.get.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {
|
||||
"clusters": [
|
||||
{
|
||||
"clusterId": "c-1",
|
||||
"state": "ACTIVE",
|
||||
"userPrefix": "pfx",
|
||||
"endpoints": {"public": {"host": "gw.tidbcloud.com"}},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
|
||||
|
||||
assert binding.account == "pfx.root"
|
||||
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
|
||||
assert binding.status == TidbAuthBindingStatus.ACTIVE
|
||||
mock_db.session.add.assert_called_once_with(binding)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.db")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint):
|
||||
binding = SimpleNamespace(
|
||||
cluster_id="c-1",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
account="root",
|
||||
qdrant_endpoint=None,
|
||||
)
|
||||
mock_http.get.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {
|
||||
"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]
|
||||
},
|
||||
)
|
||||
|
||||
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
|
||||
|
||||
assert binding.account == "pfx.root"
|
||||
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
|
||||
assert binding.status == TidbAuthBindingStatus.ACTIVE
|
||||
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
|
||||
mock_db.session.add.assert_called_once_with(binding)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.db")
|
||||
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
|
||||
def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint):
|
||||
binding = SimpleNamespace(
|
||||
cluster_id="c-1",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
account="root",
|
||||
qdrant_endpoint=None,
|
||||
)
|
||||
mock_http.get.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {
|
||||
"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]
|
||||
},
|
||||
)
|
||||
|
||||
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
|
||||
|
||||
assert binding.account == "pfx.root"
|
||||
assert binding.qdrant_endpoint is None
|
||||
assert binding.status == TidbAuthBindingStatus.CREATING
|
||||
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
|
||||
mock_db.session.add.assert_called_once_with(binding)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
@ -1,349 +0,0 @@
|
||||
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enums.quota_type import QuotaType
|
||||
from services.quota_service import QuotaCharge, QuotaService, unlimited
|
||||
|
||||
|
||||
class TestQuotaType:
|
||||
def test_billing_key_trigger(self):
|
||||
assert QuotaType.TRIGGER.billing_key == "trigger_event"
|
||||
|
||||
def test_billing_key_workflow(self):
|
||||
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
|
||||
|
||||
def test_billing_key_unlimited_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid quota type"):
|
||||
_ = QuotaType.UNLIMITED.billing_key
|
||||
|
||||
|
||||
class TestQuotaService:
|
||||
def test_reserve_billing_disabled(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService"),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
|
||||
def test_reserve_zero_amount_raises(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
|
||||
|
||||
def test_reserve_success(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
|
||||
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
|
||||
|
||||
assert charge.success is True
|
||||
assert charge.charge_id == "rid-1"
|
||||
assert charge._tenant_id == "t1"
|
||||
assert charge._feature_key == "trigger_event"
|
||||
assert charge._amount == 1
|
||||
mock_bs.quota_reserve.assert_called_once()
|
||||
|
||||
def test_reserve_no_reservation_id_raises(self):
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {}
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
|
||||
def test_reserve_quota_exceeded_propagates(self):
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
|
||||
def test_reserve_api_exception_returns_unlimited(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.side_effect = RuntimeError("network")
|
||||
|
||||
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
|
||||
def test_consume_calls_reserve_and_commit(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
|
||||
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
|
||||
assert charge.success is True
|
||||
mock_bs.quota_commit.assert_called_once()
|
||||
|
||||
def test_check_billing_disabled(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
|
||||
|
||||
def test_check_zero_amount_raises(self):
|
||||
with patch("services.quota_service.dify_config") as mock_cfg:
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
|
||||
|
||||
def test_check_sufficient_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=100),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
|
||||
|
||||
def test_check_insufficient_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=5),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
|
||||
|
||||
def test_check_unlimited_quota(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", return_value=-1),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
|
||||
|
||||
def test_check_exception_returns_true(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
|
||||
|
||||
def test_release_billing_disabled(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = False
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_not_called()
|
||||
|
||||
def test_release_empty_reservation(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_not_called()
|
||||
|
||||
def test_release_success(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_release.return_value = {}
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
mock_bs.quota_release.assert_called_once_with(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
|
||||
)
|
||||
|
||||
def test_release_exception_swallowed(self):
|
||||
with (
|
||||
patch("services.quota_service.dify_config") as mock_cfg,
|
||||
patch("services.billing_service.BillingService") as mock_bs,
|
||||
):
|
||||
mock_cfg.BILLING_ENABLED = True
|
||||
mock_bs.quota_release.side_effect = RuntimeError("fail")
|
||||
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
|
||||
def test_get_remaining_normal(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
|
||||
|
||||
def test_get_remaining_unlimited(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
|
||||
|
||||
def test_get_remaining_over_limit_returns_zero(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_exception_returns_neg1(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.side_effect = RuntimeError
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
|
||||
|
||||
def test_get_remaining_empty_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_non_dict_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = "invalid"
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
def test_get_remaining_feature_not_in_response(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
|
||||
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
|
||||
assert remaining == 0
|
||||
|
||||
def test_get_remaining_non_dict_feature_info(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
|
||||
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
|
||||
|
||||
|
||||
class TestQuotaCharge:
|
||||
def test_commit_success(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_called_once_with(
|
||||
tenant_id="t1",
|
||||
feature_key="trigger_event",
|
||||
reservation_id="rid-1",
|
||||
actual_amount=1,
|
||||
)
|
||||
assert charge._committed is True
|
||||
|
||||
def test_commit_with_actual_amount(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=10,
|
||||
)
|
||||
charge.commit(actual_amount=5)
|
||||
call_kwargs = mock_bs.quota_commit.call_args[1]
|
||||
assert call_kwargs["actual_amount"] == 5
|
||||
|
||||
def test_commit_idempotent(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.return_value = {}
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
charge.commit()
|
||||
assert mock_bs.quota_commit.call_count == 1
|
||||
|
||||
def test_commit_no_charge_id_noop(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_not_called()
|
||||
|
||||
def test_commit_no_tenant_id_noop(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id=None,
|
||||
_feature_key="trigger_event",
|
||||
)
|
||||
charge.commit()
|
||||
mock_bs.quota_commit.assert_not_called()
|
||||
|
||||
def test_commit_exception_swallowed(self):
|
||||
with patch("services.billing_service.BillingService") as mock_bs:
|
||||
mock_bs.quota_commit.side_effect = RuntimeError("fail")
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
_amount=1,
|
||||
)
|
||||
charge.commit()
|
||||
|
||||
def test_refund_success(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id="t1",
|
||||
_feature_key="trigger_event",
|
||||
)
|
||||
charge.refund()
|
||||
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
|
||||
|
||||
def test_refund_no_charge_id_noop(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
|
||||
charge.refund()
|
||||
mock_rel.assert_not_called()
|
||||
|
||||
def test_refund_no_tenant_id_noop(self):
|
||||
with patch.object(QuotaService, "release") as mock_rel:
|
||||
charge = QuotaCharge(
|
||||
success=True,
|
||||
charge_id="rid-1",
|
||||
_quota_type=QuotaType.TRIGGER,
|
||||
_tenant_id=None,
|
||||
)
|
||||
charge.refund()
|
||||
mock_rel.assert_not_called()
|
||||
|
||||
|
||||
class TestUnlimited:
|
||||
def test_unlimited_returns_success_with_no_charge_id(self):
|
||||
charge = unlimited()
|
||||
assert charge.success is True
|
||||
assert charge.charge_id is None
|
||||
assert charge._quota_type == QuotaType.UNLIMITED
|
||||
@ -23,7 +23,6 @@ import pytest
|
||||
|
||||
import services.app_generate_service as ags_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from enums.quota_type import QuotaType
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
@ -448,8 +447,8 @@ class TestGenerateBilling:
|
||||
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
reserve_mock = mocker.patch(
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
consume_mock = mocker.patch(
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
@ -468,8 +467,7 @@ class TestGenerateBilling:
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
|
||||
quota_charge.commit.assert_called_once()
|
||||
consume_mock.assert_called_once_with("tenant-id")
|
||||
|
||||
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
|
||||
from services.errors.app import QuotaExceededError
|
||||
@ -477,7 +475,7 @@ class TestGenerateBilling:
|
||||
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
|
||||
)
|
||||
|
||||
@ -494,7 +492,7 @@ class TestGenerateBilling:
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
|
||||
quota_charge = MagicMock()
|
||||
mocker.patch(
|
||||
"services.app_generate_service.QuotaService.reserve",
|
||||
"services.app_generate_service.QuotaType.WORKFLOW.consume",
|
||||
return_value=quota_charge,
|
||||
)
|
||||
mocker.patch(
|
||||
|
||||
@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
|
||||
- repo: SQLAlchemyWorkflowTriggerLogRepository
|
||||
- dispatcher_manager_class: QueueDispatcherManager class
|
||||
- dispatcher: dispatcher instance
|
||||
- quota_service: QuotaService mock
|
||||
- quota_workflow: QuotaType.WORKFLOW
|
||||
- get_workflow: AsyncWorkflowService._get_workflow method
|
||||
- professional_task: execute_workflow_professional
|
||||
- team_task: execute_workflow_team
|
||||
@ -72,7 +72,12 @@ class TestAsyncWorkflowService:
|
||||
mock_repo.create.side_effect = _create_side_effect
|
||||
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_quota_service = MagicMock()
|
||||
quota_workflow = MagicMock()
|
||||
mock_get_workflow = MagicMock()
|
||||
|
||||
mock_professional_task = MagicMock()
|
||||
mock_team_task = MagicMock()
|
||||
mock_sandbox_task = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
@ -88,8 +93,8 @@ class TestAsyncWorkflowService:
|
||||
) as mock_get_workflow,
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"QuotaService",
|
||||
new=mock_quota_service,
|
||||
"QuotaType",
|
||||
new=SimpleNamespace(WORKFLOW=quota_workflow),
|
||||
),
|
||||
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
|
||||
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
|
||||
@ -102,7 +107,7 @@ class TestAsyncWorkflowService:
|
||||
"repo": mock_repo,
|
||||
"dispatcher_manager_class": mock_dispatcher_manager_class,
|
||||
"dispatcher": mock_dispatcher,
|
||||
"quota_service": mock_quota_service,
|
||||
"quota_workflow": quota_workflow,
|
||||
"get_workflow": mock_get_workflow,
|
||||
"professional_task": mock_professional_task,
|
||||
"team_task": mock_team_task,
|
||||
@ -141,9 +146,6 @@ class TestAsyncWorkflowService:
|
||||
mocks["team_task"].delay.return_value = task_result
|
||||
mocks["sandbox_task"].delay.return_value = task_result
|
||||
|
||||
quota_charge_mock = MagicMock()
|
||||
mocks["quota_service"].reserve.return_value = quota_charge_mock
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id: str):
|
||||
self.id = user_id
|
||||
@ -161,9 +163,8 @@ class TestAsyncWorkflowService:
|
||||
assert result.status == "queued"
|
||||
assert result.queue == queue_name
|
||||
|
||||
mocks["quota_service"].reserve.assert_called_once()
|
||||
quota_charge_mock.commit.assert_called_once()
|
||||
assert session.commit.call_count == 3
|
||||
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
|
||||
assert session.commit.call_count == 2
|
||||
|
||||
created_log = mocks["repo"].create.call_args[0][0]
|
||||
assert created_log.status == WorkflowTriggerStatus.QUEUED
|
||||
@ -249,7 +250,7 @@ class TestAsyncWorkflowService:
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
|
||||
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
|
||||
feature="workflow",
|
||||
tenant_id="tenant-123",
|
||||
required=1,
|
||||
@ -266,7 +267,7 @@ class TestAsyncWorkflowService:
|
||||
trigger_data=trigger_data,
|
||||
)
|
||||
|
||||
assert session.commit.call_count == 3
|
||||
assert session.commit.call_count == 2
|
||||
updated_log = mocks["repo"].update.call_args[0][0]
|
||||
assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED
|
||||
assert "Quota limit reached" in updated_log.error
|
||||
@ -462,7 +463,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123", session=None)
|
||||
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
|
||||
workflow_service.get_published_workflow.assert_not_called()
|
||||
|
||||
def test_should_raise_when_specific_workflow_id_not_found(self):
|
||||
@ -490,7 +491,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None)
|
||||
workflow_service.get_published_workflow.assert_called_once_with(app_model)
|
||||
workflow_service.get_published_workflow_by_id.assert_not_called()
|
||||
|
||||
def test_should_raise_when_default_published_workflow_not_found(self):
|
||||
|
||||
@ -290,19 +290,9 @@ class TestBillingServiceSubscriptionInfo:
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {
|
||||
"enabled": True,
|
||||
"subscription": {"plan": "professional", "interval": "month", "education": False},
|
||||
"members": {"size": 1, "limit": 50},
|
||||
"apps": {"size": 1, "limit": 200},
|
||||
"vector_space": {"size": 0.0, "limit": 20480},
|
||||
"knowledge_rate_limit": {"limit": 1000},
|
||||
"documents_upload_quota": {"size": 0, "limit": 1000},
|
||||
"annotation_quota_limit": {"size": 0, "limit": 5000},
|
||||
"docs_processing": "top-priority",
|
||||
"can_replace_logo": True,
|
||||
"model_load_balancing_enabled": True,
|
||||
"knowledge_pipeline_publish_enabled": True,
|
||||
"next_credit_reset_date": 1775952000,
|
||||
"subscription_plan": "professional",
|
||||
"billing_cycle": "monthly",
|
||||
"status": "active",
|
||||
}
|
||||
mock_send_request.return_value = expected_response
|
||||
|
||||
@ -425,7 +415,7 @@ class TestBillingServiceUsageCalculation:
|
||||
yield mock
|
||||
|
||||
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
|
||||
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
|
||||
"""Test retrieval of tenant feature plan usage information."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
|
||||
@ -438,20 +428,6 @@ class TestBillingServiceUsageCalculation:
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
|
||||
|
||||
def test_get_quota_info(self, mock_send_request):
|
||||
"""Test retrieval of quota info from new endpoint."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
|
||||
mock_send_request.return_value = expected_response
|
||||
|
||||
# Act
|
||||
result = BillingService.get_quota_info(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
|
||||
|
||||
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
|
||||
"""Test updating tenant feature usage with positive delta (adding credits)."""
|
||||
# Arrange
|
||||
@ -529,150 +505,6 @@ class TestBillingServiceUsageCalculation:
|
||||
)
|
||||
|
||||
|
||||
class TestBillingServiceQuotaOperations:
|
||||
"""Unit tests for quota reserve/commit/release operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_request(self):
|
||||
with patch.object(BillingService, "_send_request") as mock:
|
||||
yield mock
|
||||
|
||||
def test_quota_reserve_success(self, mock_send_request):
|
||||
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/reserve",
|
||||
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
|
||||
)
|
||||
|
||||
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
|
||||
|
||||
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
|
||||
|
||||
assert result["available"] == 99
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["reserved"] == 1
|
||||
assert isinstance(result["reserved"], int)
|
||||
|
||||
def test_quota_reserve_with_meta(self, mock_send_request):
|
||||
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
|
||||
meta = {"source": "webhook"}
|
||||
|
||||
BillingService.quota_reserve(
|
||||
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
|
||||
)
|
||||
|
||||
call_json = mock_send_request.call_args[1]["json"]
|
||||
assert call_json["meta"] == {"source": "webhook"}
|
||||
|
||||
def test_quota_commit_success(self, mock_send_request):
|
||||
expected = {"available": 98, "reserved": 0, "refunded": 0}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/commit",
|
||||
json={
|
||||
"tenant_id": "t1",
|
||||
"feature_key": "trigger_event",
|
||||
"reservation_id": "rid-1",
|
||||
"actual_amount": 1,
|
||||
},
|
||||
)
|
||||
|
||||
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
|
||||
|
||||
result = BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
|
||||
)
|
||||
|
||||
assert result["available"] == 97
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["refunded"] == 1
|
||||
assert isinstance(result["refunded"], int)
|
||||
|
||||
def test_quota_commit_with_meta(self, mock_send_request):
|
||||
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
|
||||
meta = {"reason": "partial"}
|
||||
|
||||
BillingService.quota_commit(
|
||||
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
|
||||
)
|
||||
|
||||
call_json = mock_send_request.call_args[1]["json"]
|
||||
assert call_json["meta"] == {"reason": "partial"}
|
||||
|
||||
def test_quota_release_success(self, mock_send_request):
|
||||
expected = {"available": 100, "reserved": 0, "released": 1}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
|
||||
|
||||
assert result == expected
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/quota/release",
|
||||
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
|
||||
)
|
||||
|
||||
def test_quota_release_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int."""
|
||||
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
|
||||
|
||||
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
|
||||
|
||||
assert result["available"] == 100
|
||||
assert isinstance(result["available"], int)
|
||||
assert result["released"] == 1
|
||||
assert isinstance(result["released"], int)
|
||||
|
||||
def test_get_quota_info_coerces_string_to_int(self, mock_send_request):
|
||||
"""Test that TypeAdapter coerces string values to int for get_quota_info."""
|
||||
mock_send_request.return_value = {
|
||||
"trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"},
|
||||
"api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"},
|
||||
}
|
||||
|
||||
result = BillingService.get_quota_info("t1")
|
||||
|
||||
assert result["trigger_event"]["usage"] == 42
|
||||
assert isinstance(result["trigger_event"]["usage"], int)
|
||||
assert result["trigger_event"]["limit"] == 3000
|
||||
assert isinstance(result["trigger_event"]["limit"], int)
|
||||
assert result["trigger_event"]["reset_date"] == 1700000000
|
||||
assert isinstance(result["trigger_event"]["reset_date"], int)
|
||||
assert result["api_rate_limit"]["limit"] == -1
|
||||
assert isinstance(result["api_rate_limit"]["limit"], int)
|
||||
|
||||
def test_get_quota_info_accepts_int_values(self, mock_send_request):
|
||||
"""Test that get_quota_info works with native int values."""
|
||||
expected = {
|
||||
"trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000},
|
||||
"api_rate_limit": {"usage": 0, "limit": -1},
|
||||
}
|
||||
mock_send_request.return_value = expected
|
||||
|
||||
result = BillingService.get_quota_info("t1")
|
||||
|
||||
assert result["trigger_event"]["usage"] == 42
|
||||
assert result["trigger_event"]["limit"] == 3000
|
||||
assert result["api_rate_limit"]["limit"] == -1
|
||||
|
||||
|
||||
class TestBillingServiceRateLimitEnforcement:
|
||||
"""Unit tests for rate limit enforcement mechanisms.
|
||||
|
||||
@ -1177,14 +1009,17 @@ class TestBillingServiceEdgeCases:
|
||||
yield mock
|
||||
|
||||
def test_get_info_empty_response(self, mock_send_request):
|
||||
"""Empty response from billing API should raise ValidationError due to missing required fields."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
"""Test handling of empty billing info response."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-empty"
|
||||
mock_send_request.return_value = {}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
BillingService.get_info(tenant_id)
|
||||
# Act
|
||||
result = BillingService.get_info(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
mock_send_request.assert_called_once()
|
||||
|
||||
def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request):
|
||||
"""Test updating tenant feature usage with zero delta (no change)."""
|
||||
@ -1599,21 +1434,12 @@ class TestBillingServiceIntegrationScenarios:
|
||||
|
||||
# Step 1: Get current billing info
|
||||
mock_send_request.return_value = {
|
||||
"enabled": True,
|
||||
"subscription": {"plan": "sandbox", "interval": "", "education": False},
|
||||
"members": {"size": 0, "limit": 1},
|
||||
"apps": {"size": 0, "limit": 5},
|
||||
"vector_space": {"size": 0.0, "limit": 50},
|
||||
"knowledge_rate_limit": {"limit": 10},
|
||||
"documents_upload_quota": {"size": 0, "limit": 50},
|
||||
"annotation_quota_limit": {"size": 0, "limit": 10},
|
||||
"docs_processing": "standard",
|
||||
"can_replace_logo": False,
|
||||
"model_load_balancing_enabled": False,
|
||||
"knowledge_pipeline_publish_enabled": False,
|
||||
"subscription_plan": "sandbox",
|
||||
"billing_cycle": "monthly",
|
||||
"status": "active",
|
||||
}
|
||||
current_info = BillingService.get_info(tenant_id)
|
||||
assert current_info["subscription"]["plan"] == "sandbox"
|
||||
assert current_info["subscription_plan"] == "sandbox"
|
||||
|
||||
# Step 2: Get payment link for upgrade
|
||||
mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"}
|
||||
@ -1727,140 +1553,3 @@ class TestBillingServiceIntegrationScenarios:
|
||||
mock_send_request.return_value = {"result": "success", "activated": True}
|
||||
activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student")
|
||||
assert activate_result["activated"] is True
|
||||
|
||||
|
||||
class TestBillingServiceSubscriptionInfoDataType:
|
||||
"""Unit tests for data type coercion in BillingService.get_info
|
||||
|
||||
1. Verifies the get_info returns correct Python types for numeric fields
|
||||
2. Ensure the compatibility regardless of what results the upstream billing API returns
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send_request(self):
|
||||
with patch.object(BillingService, "_send_request") as mock:
|
||||
yield mock
|
||||
|
||||
@pytest.fixture
|
||||
def normal_billing_response(self) -> dict:
|
||||
return {
|
||||
"enabled": True,
|
||||
"subscription": {
|
||||
"plan": "team",
|
||||
"interval": "year",
|
||||
"education": False,
|
||||
},
|
||||
"members": {"size": 10, "limit": 50},
|
||||
"apps": {"size": 80, "limit": 200},
|
||||
"vector_space": {"size": 5120.75, "limit": 20480},
|
||||
"knowledge_rate_limit": {"limit": 1000},
|
||||
"documents_upload_quota": {"size": 450, "limit": 1000},
|
||||
"annotation_quota_limit": {"size": 1200, "limit": 5000},
|
||||
"docs_processing": "top-priority",
|
||||
"can_replace_logo": True,
|
||||
"model_load_balancing_enabled": True,
|
||||
"knowledge_pipeline_publish_enabled": True,
|
||||
"next_credit_reset_date": 1745971200,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def string_billing_response(self) -> dict:
|
||||
return {
|
||||
"enabled": True,
|
||||
"subscription": {
|
||||
"plan": "team",
|
||||
"interval": "year",
|
||||
"education": False,
|
||||
},
|
||||
"members": {"size": "10", "limit": "50"},
|
||||
"apps": {"size": "80", "limit": "200"},
|
||||
"vector_space": {"size": 5120.75, "limit": "20480"},
|
||||
"knowledge_rate_limit": {"limit": "1000"},
|
||||
"documents_upload_quota": {"size": "450", "limit": "1000"},
|
||||
"annotation_quota_limit": {"size": "1200", "limit": "5000"},
|
||||
"docs_processing": "top-priority",
|
||||
"can_replace_logo": True,
|
||||
"model_load_balancing_enabled": True,
|
||||
"knowledge_pipeline_publish_enabled": True,
|
||||
"next_credit_reset_date": "1745971200",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _assert_billing_info_types(result: dict):
|
||||
assert isinstance(result["enabled"], bool)
|
||||
assert isinstance(result["subscription"]["plan"], str)
|
||||
assert isinstance(result["subscription"]["interval"], str)
|
||||
assert isinstance(result["subscription"]["education"], bool)
|
||||
|
||||
assert isinstance(result["members"]["size"], int)
|
||||
assert isinstance(result["members"]["limit"], int)
|
||||
|
||||
assert isinstance(result["apps"]["size"], int)
|
||||
assert isinstance(result["apps"]["limit"], int)
|
||||
|
||||
assert isinstance(result["vector_space"]["size"], float)
|
||||
assert isinstance(result["vector_space"]["limit"], int)
|
||||
|
||||
assert isinstance(result["knowledge_rate_limit"]["limit"], int)
|
||||
|
||||
assert isinstance(result["documents_upload_quota"]["size"], int)
|
||||
assert isinstance(result["documents_upload_quota"]["limit"], int)
|
||||
|
||||
assert isinstance(result["annotation_quota_limit"]["size"], int)
|
||||
assert isinstance(result["annotation_quota_limit"]["limit"], int)
|
||||
|
||||
assert isinstance(result["docs_processing"], str)
|
||||
assert isinstance(result["can_replace_logo"], bool)
|
||||
assert isinstance(result["model_load_balancing_enabled"], bool)
|
||||
assert isinstance(result["knowledge_pipeline_publish_enabled"], bool)
|
||||
if "next_credit_reset_date" in result:
|
||||
assert isinstance(result["next_credit_reset_date"], int)
|
||||
|
||||
def test_get_info_with_normal_types(self, mock_send_request, normal_billing_response):
|
||||
"""When the billing API returns native numeric types, get_info should preserve them."""
|
||||
mock_send_request.return_value = normal_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
self._assert_billing_info_types(result)
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
|
||||
|
||||
def test_get_info_with_string_types(self, mock_send_request, string_billing_response):
|
||||
"""When the billing API returns numeric values as strings, get_info should coerce them."""
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
self._assert_billing_info_types(result)
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
|
||||
|
||||
def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response):
|
||||
"""NotRequired fields can be absent without raising."""
|
||||
del string_billing_response["next_credit_reset_date"]
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
assert "next_credit_reset_date" not in result
|
||||
self._assert_billing_info_types(result)
|
||||
|
||||
def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response):
|
||||
"""Undefined fields are silently stripped by validate_python."""
|
||||
string_billing_response["new_feature"] = "something"
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
# extra fields are dropped by TypeAdapter on TypedDict
|
||||
assert "new_feature" not in result
|
||||
self._assert_billing_info_types(result)
|
||||
|
||||
def test_get_info_missing_required_field_raises(self, mock_send_request, string_billing_response):
|
||||
"""Missing a required field should raise ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
del string_billing_response["members"]
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
BillingService.get_info("tenant-type-test")
|
||||
|
||||
@ -337,7 +337,10 @@ class TestWorkflowService:
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
|
||||
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_draft_workflow(app)
|
||||
|
||||
@ -347,7 +350,10 @@ class TestWorkflowService:
|
||||
"""Test get_draft_workflow returns None when no draft exists."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
|
||||
mock_db_session.session.scalar.return_value = None
|
||||
# Mock database query to return None
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
result = workflow_service.get_draft_workflow(app)
|
||||
|
||||
@ -359,7 +365,10 @@ class TestWorkflowService:
|
||||
workflow_id = "workflow-123"
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
|
||||
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id)
|
||||
|
||||
@ -374,7 +383,10 @@ class TestWorkflowService:
|
||||
workflow_id = "workflow-123"
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
|
||||
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
|
||||
|
||||
@ -393,7 +405,10 @@ class TestWorkflowService:
|
||||
workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
|
||||
)
|
||||
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
|
||||
with pytest.raises(IsDraftWorkflowError):
|
||||
workflow_service.get_published_workflow_by_id(app, workflow_id)
|
||||
@ -403,7 +418,10 @@ class TestWorkflowService:
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
workflow_id = "nonexistent-workflow"
|
||||
|
||||
mock_db_session.session.scalar.return_value = None
|
||||
# Mock database query to return None
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
|
||||
|
||||
@ -415,7 +433,10 @@ class TestWorkflowService:
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
|
||||
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.get_published_workflow(app)
|
||||
|
||||
@ -444,7 +465,11 @@ class TestWorkflowService:
|
||||
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
|
||||
features = {"file_upload": {"enabled": False}}
|
||||
|
||||
mock_db_session.session.scalar.return_value = None
|
||||
# Mock get_draft_workflow to return None (no existing draft)
|
||||
# This simulates the first time a workflow is created for an app
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(workflow_service, "validate_features_structure"),
|
||||
@ -481,7 +506,9 @@ class TestWorkflowService:
|
||||
# Mock existing draft workflow
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash)
|
||||
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
|
||||
with (
|
||||
patch.object(workflow_service, "validate_features_structure"),
|
||||
@ -520,7 +547,9 @@ class TestWorkflowService:
|
||||
# Mock existing draft workflow with different hash
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash")
|
||||
|
||||
mock_db_session.session.scalar.return_value = mock_workflow
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.session.query.return_value = mock_query
|
||||
mock_query.where.return_value.first.return_value = mock_workflow
|
||||
|
||||
with pytest.raises(WorkflowHashNotEqualError):
|
||||
workflow_service.sync_draft_workflow(
|
||||
|
||||
@ -1,204 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import tasks.trigger_processing_tasks as trigger_processing_tasks_module
|
||||
from services.errors.app import QuotaExceededError
|
||||
from tasks.trigger_processing_tasks import dispatch_triggered_workflow
|
||||
|
||||
|
||||
class TestDispatchTriggeredWorkflow:
|
||||
"""Unit tests covering branch behaviours of ``dispatch_triggered_workflow``.
|
||||
|
||||
The covered branches are:
|
||||
- workflow missing for ``plugin_trigger.app_id`` → log + ``continue``
|
||||
- ``QuotaService.reserve`` raising ``QuotaExceededError`` →
|
||||
``mark_tenant_triggers_rate_limited`` + early ``return``
|
||||
- ``trigger_workflow_async`` succeeds →
|
||||
``quota_charge.commit()`` + ``dispatched_count`` increments
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def subscription(self):
|
||||
sub = MagicMock()
|
||||
sub.id = "subscription-123"
|
||||
sub.tenant_id = "tenant-123"
|
||||
sub.provider_id = "langgenius/test_plugin/test_plugin"
|
||||
sub.endpoint_id = "endpoint-123"
|
||||
sub.credentials = {}
|
||||
sub.credential_type = "api_key"
|
||||
return sub
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_trigger(self):
|
||||
trigger = MagicMock()
|
||||
trigger.id = "plugin-trigger-123"
|
||||
trigger.app_id = "app-123"
|
||||
trigger.node_id = "node-123"
|
||||
return trigger
|
||||
|
||||
@pytest.fixture
|
||||
def provider_controller(self):
|
||||
controller = MagicMock()
|
||||
controller.plugin_unique_identifier = "langgenius/test_plugin:0.0.1"
|
||||
controller.entity.identity.name = "Test Plugin"
|
||||
controller.entity.identity.icon = "icon.svg"
|
||||
controller.entity.identity.icon_dark = "icon_dark.svg"
|
||||
return controller
|
||||
|
||||
@pytest.fixture
|
||||
def dispatch_mocks(self, subscription, plugin_trigger, provider_controller):
|
||||
"""Patch all external dependencies reached by ``dispatch_triggered_workflow``.
|
||||
|
||||
Defaults are configured so the code flow can reach the final async
|
||||
trigger block (line ~385); each test overrides specific handles
|
||||
(``get_workflows``, ``reserve``, ``create_end_user_batch``, ...) to
|
||||
drive the path it targets.
|
||||
"""
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = MagicMock()
|
||||
session_cm.__exit__.return_value = False
|
||||
|
||||
invoke_response = MagicMock()
|
||||
invoke_response.cancelled = False
|
||||
invoke_response.variables = {}
|
||||
|
||||
quota_charge = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.TriggerHttpRequestCachingService,
|
||||
"get_request",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.TriggerHttpRequestCachingService,
|
||||
"get_payload",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.TriggerSubscriptionOperatorService,
|
||||
"get_subscriber_triggers",
|
||||
return_value=[plugin_trigger],
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.TriggerManager,
|
||||
"get_trigger_provider",
|
||||
return_value=provider_controller,
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.TriggerManager,
|
||||
"invoke_trigger_event",
|
||||
return_value=invoke_response,
|
||||
) as invoke_trigger_event,
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.TriggerEventNodeData,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module,
|
||||
"_get_latest_workflows_by_app_ids",
|
||||
) as get_workflows,
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.EndUserService,
|
||||
"create_end_user_batch",
|
||||
return_value={},
|
||||
) as create_end_user_batch,
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.session_factory,
|
||||
"create_session",
|
||||
return_value=session_cm,
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.QuotaService,
|
||||
"reserve",
|
||||
return_value=quota_charge,
|
||||
) as reserve,
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.AppTriggerService,
|
||||
"mark_tenant_triggers_rate_limited",
|
||||
) as mark_rate_limited,
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.AsyncWorkflowService,
|
||||
"trigger_workflow_async",
|
||||
) as trigger_workflow_async,
|
||||
):
|
||||
yield {
|
||||
"get_workflows": get_workflows,
|
||||
"reserve": reserve,
|
||||
"quota_charge": quota_charge,
|
||||
"mark_rate_limited": mark_rate_limited,
|
||||
"invoke_trigger_event": invoke_trigger_event,
|
||||
"invoke_response": invoke_response,
|
||||
"create_end_user_batch": create_end_user_batch,
|
||||
"trigger_workflow_async": trigger_workflow_async,
|
||||
}
|
||||
|
||||
def test_dispatch_skips_when_workflow_missing(self, subscription, dispatch_mocks):
|
||||
"""Covers missing workflow → log + ``continue``."""
|
||||
dispatch_mocks["get_workflows"].return_value = {}
|
||||
|
||||
dispatched = dispatch_triggered_workflow(
|
||||
user_id="user-123",
|
||||
subscription=subscription,
|
||||
event_name="test_event",
|
||||
request_id="request-123",
|
||||
)
|
||||
|
||||
assert dispatched == 0
|
||||
dispatch_mocks["reserve"].assert_not_called()
|
||||
dispatch_mocks["invoke_trigger_event"].assert_not_called()
|
||||
dispatch_mocks["mark_rate_limited"].assert_not_called()
|
||||
|
||||
def test_dispatch_marks_rate_limited_when_quota_exceeded(self, subscription, plugin_trigger, dispatch_mocks):
|
||||
"""Covers QuotaExceededError → mark rate-limited + early return."""
|
||||
workflow_mock = MagicMock()
|
||||
workflow_mock.walk_nodes.return_value = iter(
|
||||
[(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})]
|
||||
)
|
||||
dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock}
|
||||
dispatch_mocks["reserve"].side_effect = QuotaExceededError(
|
||||
feature="trigger", tenant_id=subscription.tenant_id, required=1
|
||||
)
|
||||
|
||||
dispatched = dispatch_triggered_workflow(
|
||||
user_id="user-123",
|
||||
subscription=subscription,
|
||||
event_name="test_event",
|
||||
request_id="request-123",
|
||||
)
|
||||
|
||||
assert dispatched == 0
|
||||
dispatch_mocks["reserve"].assert_called_once()
|
||||
dispatch_mocks["mark_rate_limited"].assert_called_once_with(subscription.tenant_id)
|
||||
dispatch_mocks["invoke_trigger_event"].assert_not_called()
|
||||
|
||||
def test_dispatch_commits_quota_and_counts_when_workflow_triggered(
|
||||
self, subscription, plugin_trigger, dispatch_mocks
|
||||
):
|
||||
"""Happy path: end user exists and async trigger succeeds."""
|
||||
workflow_mock = MagicMock()
|
||||
workflow_mock.id = "workflow-123"
|
||||
workflow_mock.walk_nodes.return_value = iter(
|
||||
[(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})]
|
||||
)
|
||||
dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock}
|
||||
|
||||
end_user_mock = MagicMock()
|
||||
dispatch_mocks["create_end_user_batch"].return_value = {plugin_trigger.app_id: end_user_mock}
|
||||
|
||||
dispatched = dispatch_triggered_workflow(
|
||||
user_id="user-123",
|
||||
subscription=subscription,
|
||||
event_name="test_event",
|
||||
request_id="request-123",
|
||||
)
|
||||
|
||||
assert dispatched == 1
|
||||
dispatch_mocks["trigger_workflow_async"].assert_called_once()
|
||||
_, kwargs = dispatch_mocks["trigger_workflow_async"].call_args
|
||||
assert kwargs["user"] is end_user_mock
|
||||
dispatch_mocks["quota_charge"].commit.assert_called_once()
|
||||
dispatch_mocks["quota_charge"].refund.assert_not_called()
|
||||
dispatch_mocks["mark_rate_limited"].assert_not_called()
|
||||
2
api/uv.lock
generated
2
api/uv.lock
generated
@ -1457,7 +1457,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.13.3"
|
||||
version = "1.13.2"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
|
||||
@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.3
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.3
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.3
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.3
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
@ -245,7 +245,7 @@ services:
|
||||
|
||||
# The DifySandbox
|
||||
sandbox:
|
||||
image: langgenius/dify-sandbox:0.2.14
|
||||
image: langgenius/dify-sandbox:0.2.12
|
||||
restart: always
|
||||
environment:
|
||||
# The DifySandbox configurations
|
||||
@ -269,7 +269,7 @@ services:
|
||||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.5.3-local
|
||||
image: langgenius/dify-plugin-daemon:0.5.4-local
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
||||
@ -97,7 +97,7 @@ services:
|
||||
|
||||
# The DifySandbox
|
||||
sandbox:
|
||||
image: langgenius/dify-sandbox:0.2.14
|
||||
image: langgenius/dify-sandbox:0.2.12
|
||||
restart: always
|
||||
env_file:
|
||||
- ./middleware.env
|
||||
@ -123,7 +123,7 @@ services:
|
||||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.5.3-local
|
||||
image: langgenius/dify-plugin-daemon:0.5.4-local
|
||||
restart: always
|
||||
env_file:
|
||||
- ./middleware.env
|
||||
|
||||
@ -731,7 +731,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.3
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -773,7 +773,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.3
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -812,7 +812,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.3
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -842,7 +842,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.3
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
@ -955,7 +955,7 @@ services:
|
||||
|
||||
# The DifySandbox
|
||||
sandbox:
|
||||
image: langgenius/dify-sandbox:0.2.14
|
||||
image: langgenius/dify-sandbox:0.2.12
|
||||
restart: always
|
||||
environment:
|
||||
# The DifySandbox configurations
|
||||
@ -979,7 +979,7 @@ services:
|
||||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.5.3-local
|
||||
image: langgenius/dify-plugin-daemon:0.5.4-local
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
||||
@ -5,8 +5,7 @@ app:
|
||||
max_workers: 4
|
||||
max_requests: 50
|
||||
worker_timeout: 5
|
||||
python_path: /opt/python/bin/python3
|
||||
nodejs_path: /usr/local/bin/node
|
||||
python_path: /usr/local/bin/python3
|
||||
enable_network: True # please make sure there is no network risk in your environment
|
||||
allowed_syscalls: # please leave it empty if you have no idea how seccomp works
|
||||
proxy:
|
||||
|
||||
@ -5,7 +5,7 @@ app:
|
||||
max_workers: 4
|
||||
max_requests: 50
|
||||
worker_timeout: 5
|
||||
python_path: /opt/python/bin/python3
|
||||
python_path: /usr/local/bin/python3
|
||||
python_lib_path:
|
||||
- /usr/local/lib/python3.10
|
||||
- /usr/lib/python3.10
|
||||
|
||||
@ -501,16 +501,6 @@ describe('Question component', () => {
|
||||
expect(onRegenerate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should render default question avatar icon when questionIcon is not provided', () => {
|
||||
const { container } = renderWithProvider(
|
||||
makeItem(),
|
||||
vi.fn() as unknown as OnRegenerate,
|
||||
)
|
||||
|
||||
const defaultIcon = container.querySelector('.question-default-user-icon')
|
||||
expect(defaultIcon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render custom questionIcon when provided', () => {
|
||||
const { container } = renderWithProvider(
|
||||
makeItem(),
|
||||
@ -519,7 +509,7 @@ describe('Question component', () => {
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('custom-question-icon')).toBeInTheDocument()
|
||||
const defaultIcon = container.querySelector('.question-default-user-icon')
|
||||
const defaultIcon = container.querySelector('.i-custom-public-avatar-user')
|
||||
expect(defaultIcon).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@ import {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Textarea from 'react-textarea-autosize'
|
||||
import { FileList } from '@/app/components/base/file-uploader'
|
||||
import { User } from '@/app/components/base/icons/src/public/avatar'
|
||||
import { Markdown } from '@/app/components/base/markdown'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import ActionButton from '../../action-button'
|
||||
@ -244,7 +243,7 @@ const Question: FC<QuestionProps> = ({
|
||||
{
|
||||
questionIcon || (
|
||||
<div className="h-full w-full rounded-full border-[0.5px] border-black/5">
|
||||
<User className="question-default-user-icon h-full w-full" />
|
||||
<div className="i-custom-public-avatar-user h-full w-full" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -142,7 +142,7 @@ const ApiKeyModal = ({
|
||||
onExtraButtonClick={onRemove}
|
||||
disabled={disabled || isLoading || doingAction}
|
||||
clickOutsideNotClose={true}
|
||||
wrapperClassName="!z-[1002]"
|
||||
wrapperClassName="!z-[101]"
|
||||
>
|
||||
{pluginPayload.detail && (
|
||||
<ReadmeEntrance pluginDetail={pluginPayload.detail} showType={ReadmeShowType.modal} />
|
||||
|
||||
@ -157,7 +157,7 @@ const OAuthClientSettings = ({
|
||||
)
|
||||
}
|
||||
containerClassName="pt-0"
|
||||
wrapperClassName="!z-[1002]"
|
||||
wrapperClassName="!z-[101]"
|
||||
clickOutsideNotClose={true}
|
||||
>
|
||||
{pluginPayload.detail && (
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "dify-web",
|
||||
"type": "module",
|
||||
"version": "1.13.3",
|
||||
"version": "1.13.2",
|
||||
"private": true,
|
||||
"packageManager": "pnpm@10.32.1",
|
||||
"imports": {
|
||||
@ -125,15 +125,15 @@
|
||||
"mime": "4.1.0",
|
||||
"mitt": "3.0.1",
|
||||
"negotiator": "1.0.0",
|
||||
"next": "16.2.3",
|
||||
"next": "16.2.1",
|
||||
"next-themes": "0.4.6",
|
||||
"nuqs": "2.8.9",
|
||||
"pinyin-pro": "3.28.0",
|
||||
"qrcode.react": "4.2.0",
|
||||
"qs": "6.15.0",
|
||||
"react": "19.2.5",
|
||||
"react": "19.2.4",
|
||||
"react-18-input-autosize": "3.0.0",
|
||||
"react-dom": "19.2.5",
|
||||
"react-dom": "19.2.4",
|
||||
"react-easy-crop": "5.5.6",
|
||||
"react-hotkeys-hook": "5.2.4",
|
||||
"react-i18next": "16.6.1",
|
||||
@ -173,8 +173,8 @@
|
||||
"@mdx-js/loader": "3.1.1",
|
||||
"@mdx-js/react": "3.1.1",
|
||||
"@mdx-js/rollup": "3.1.1",
|
||||
"@next/eslint-plugin-next": "16.2.3",
|
||||
"@next/mdx": "16.2.3",
|
||||
"@next/eslint-plugin-next": "16.2.1",
|
||||
"@next/mdx": "16.2.1",
|
||||
"@rgrove/parse-xml": "4.2.0",
|
||||
"@storybook/addon-docs": "10.3.1",
|
||||
"@storybook/addon-links": "10.3.1",
|
||||
@ -231,7 +231,7 @@
|
||||
"nock": "14.0.11",
|
||||
"postcss": "8.5.8",
|
||||
"postcss-js": "5.1.0",
|
||||
"react-server-dom-webpack": "19.2.5",
|
||||
"react-server-dom-webpack": "19.2.4",
|
||||
"sass": "1.98.0",
|
||||
"storybook": "10.3.1",
|
||||
"tailwindcss": "3.4.19",
|
||||
|
||||
1066
web/pnpm-lock.yaml
generated
1066
web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user