Compare commits

..

36 Commits

Author SHA1 Message Date
f7ae14d50e test: refactor GenericTable tests to use a helper function for option selection 2026-03-25 18:09:38 +08:00
8e093a2b25 test: update GenericTable tests to use button interactions for method selection 2026-03-25 17:41:00 +08:00
95bdeb674f Merge remote-tracking branch 'origin/main' into test/workflow-app 2026-03-25 17:28:06 +08:00
b8b0422e73 fix: enhance form validation for file inputs and improve handling of empty array values in variable modal 2026-03-25 15:47:53 +08:00
d5870d2620 test: improve performance and error handling in variable modal and translation tests 2026-03-25 15:37:18 +08:00
3200c574a6 [autofix.ci] apply automated fixes 2026-03-25 07:25:39 +00:00
ba0c911011 Merge branch 'main' into test/workflow-app 2026-03-25 15:22:30 +08:00
77e7f0a7de Merge branch 'test/workflow-part-8' into test/workflow-app 2026-03-25 15:21:27 +08:00
74e5ac4153 refactor(tests): update toast mock implementation to use new UI toast structure 2026-03-25 13:59:52 +08:00
c494f80452 [autofix.ci] apply automated fixes 2026-03-25 05:47:45 +00:00
fb4d51e750 Merge remote-tracking branch 'origin/main' into test/workflow-part-8 2026-03-25 13:44:41 +08:00
a119726469 test: migrate workflow app service tests to testcontainers (#34036)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-25 13:41:21 +08:00
yyh
8876f69c24 refactor(workflow): migrate legacy toast usage to ui toast (#34002)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-25 13:41:20 +08:00
1f8bf024e7 test: migrate tools transform service tests to testcontainers (#34035)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-25 13:33:02 +08:00
86cda295f0 fix: resolve SADeprecationWarning for callable default in remaining TypeBase models (#34049) 2026-03-25 13:33:02 +08:00
a1e6c3ee77 test: migrate app service tests to testcontainers (#34025)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-25 13:33:02 +08:00
b45f056f62 refactor: select in console datasets document controller (#34029)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-25 13:33:02 +08:00
b88e4a5e9c refactor: select in console datasets segments and API key controllers (#34027)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-25 13:33:02 +08:00
42a644cedb test: migrate advanced prompt template service tests to testcontainers (#34034) 2026-03-25 13:33:01 +08:00
0c01931bcc test: migrate webapp auth service tests to testcontainers (#34037) 2026-03-25 13:33:01 +08:00
412eb15527 test: replace indexing_technique string literals with IndexTechnique (#34042) 2026-03-25 13:33:01 +08:00
d2deab61d1 fix: update docs path (#34052) 2026-03-25 13:33:01 +08:00
lif
430fb1ee97 fix(workflow): clear loop/iteration metadata when pasting node outside container (#29983)
Co-authored-by: hjlarry <hjlarry@163.com>
2026-03-25 13:33:01 +08:00
2b0b3e3321 test(workflow-app): enhance unit tests for workflow components and hooks
- Added tests to ensure proper handling of unavailable trigger data and empty payloads in WorkflowApp.
- Implemented tests for WorkflowChildren to verify behavior with various node configurations and event handling.
- Expanded useWorkflowRun and useWorkflowStartRun tests to cover additional scenarios, including input validation and trigger execution conditions.
- Improved test coverage for utility functions related to workflow run handling and debugging.
2026-03-25 13:25:00 +08:00
9f0fcdd049 feat(workflow-app): refactor workflow app components and add utility functions
- Introduced utility functions for building initial features and mapping trigger statuses, improving code organization and readability.
- Refactored the WorkflowApp component to utilize these new utilities, enhancing maintainability.
- Added comprehensive unit tests for the new utility functions to ensure correctness and reliability.
2026-03-25 13:06:22 +08:00
727fc057d3 [autofix.ci] apply automated fixes 2026-03-25 04:16:27 +00:00
3d10cf97f1 test: add unit tests for various workflow components
- Introduced new test files for CandidateNodeMain, CustomEdge, NodeContextmenu, PanelContextmenu, HelpLine, and several hooks.
- Each test file includes comprehensive tests to validate component rendering, interactions, and state management.
- Enhanced test coverage for dynamic test run options and data source configurations.
- Ensured proper mocking of dependencies to isolate component behavior during tests.
2026-03-25 12:12:59 +08:00
373f8245af Merge branch 'main' into test/workflow-part-8 2026-03-25 11:08:05 +08:00
168ba4caa3 test(tests): increase timeout limits for translation file tests and improve element selection in component tests
- Updated timeout expectations in translation file tests to allow for longer processing times, ensuring tests pass under varying conditions.
- Refined element selection in the GenericTable and ChatRecord integration tests to use more specific queries, enhancing test reliability and clarity.
2026-03-25 11:07:47 +08:00
688ccb5aa9 fix(tests): update variable reference picker tests with type assertions
- Added type assertions for iteration and loop nodes in the variable reference picker helper tests to ensure type safety and correctness.
- Improved code clarity by explicitly defining node types, enhancing maintainability of the test suite.
2026-03-24 20:16:09 +08:00
af143312f2 refactor(workflow): consolidate selection context menu helpers into component
- Moved selection context menu helper functions directly into the SelectionContextmenu component for improved encapsulation and organization.
- Removed the separate helpers file and associated tests, streamlining the codebase.
- Updated the component to maintain existing functionality while enhancing readability and maintainability.
2026-03-24 19:47:18 +08:00
3c58c68b8d feat(workflow): implement variable modal state management and helpers
- Introduced a new custom hook  to manage the state of the variable modal, encapsulating logic for handling variable types, values, and validation.
- Created helper functions for formatting and validating variable data, improving modularity and reusability.
- Added new components for rendering variable modal sections, enhancing the user interface and user experience.
- Implemented comprehensive unit tests for the new hook and helpers to ensure functionality and correctness.
2026-03-24 18:22:35 +08:00
20f901223b [autofix.ci] apply automated fixes 2026-03-24 10:03:16 +00:00
16e8bf1cf9 Merge remote-tracking branch 'origin/main' into test/workflow-part-8 2026-03-24 17:54:37 +08:00
1943785c1c feat(workflow): add selection context menu helpers and integrate with context menu component
- Introduced helper functions for managing alignment and distribution of nodes within the workflow.
- Created a new file for selection context menu helpers, encapsulating logic for menu positioning and node alignment.
- Updated the SelectionContextmenu component to utilize the new helpers, improving code organization and readability.
- Added unit tests for the new helper functions to ensure functionality and correctness.
2026-03-24 17:53:48 +08:00
6633f5aef8 feat(workflow): implement DSL modal helpers and refactor update DSL modal
- Added helper functions for DSL validation, import status handling, and feature normalization.
- Refactored the UpdateDSLModal component to utilize the new helper functions, improving readability and maintainability.
- Introduced unit tests for the new helper functions to ensure correctness and reliability.
- Enhanced the node components with new utility functions for managing node states and rendering logic.
2026-03-24 16:16:54 +08:00
52 changed files with 1173 additions and 3771 deletions

View File

@ -4,6 +4,7 @@ import urllib.parse
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -179,7 +180,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account

View File

@ -607,15 +607,19 @@ class PublishedRagPipelineApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.publish_workflow(
session=db.session, # type: ignore[reportArgumentType,arg-type]
pipeline=pipeline,
account=current_user,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
db.session.commit()
workflow_created_at = TimestampField().format(workflow.created_at)
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
workflow = rag_pipeline_service.publish_workflow(
session=session,
pipeline=pipeline,
account=current_user,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
session.add(pipeline)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",

View File

@ -16,14 +16,12 @@ api = ExternalApi(
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
from . import mail as _mail
from .app import dsl as _app_dsl
from .plugin import plugin as _plugin
from .workspace import workspace as _workspace
api.add_namespace(inner_api_ns)
__all__ = [
"_app_dsl",
"_mail",
"_plugin",
"_workspace",

View File

@ -1 +0,0 @@

View File

@ -1,110 +0,0 @@
"""Inner API endpoints for app DSL import/export.
Called by the enterprise admin-api service. Import requires ``creator_email``
to attribute the created app; workspace/membership validation is done by the
Go admin-api caller.
"""
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only
from extensions.ext_database import db
from models import Account, App
from models.account import AccountStatus
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
class InnerAppDSLImportPayload(BaseModel):
yaml_content: str = Field(description="YAML DSL content")
creator_email: str = Field(description="Email of the workspace member who will own the imported app")
name: str | None = Field(default=None, description="Override app name from DSL")
description: str | None = Field(default=None, description="Override app description from DSL")
register_schema_model(inner_api_ns, InnerAppDSLImportPayload)
@inner_api_ns.route("/enterprise/workspaces/<string:workspace_id>/dsl/import")
class EnterpriseAppDSLImport(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_ns.doc("enterprise_app_dsl_import")
@inner_api_ns.expect(inner_api_ns.models[InnerAppDSLImportPayload.__name__])
@inner_api_ns.doc(
responses={
200: "Import completed",
202: "Import pending (DSL version mismatch requires confirmation)",
400: "Import failed (business error)",
404: "Creator account not found or inactive",
}
)
def post(self, workspace_id: str):
"""Import a DSL into a workspace on behalf of a specified creator."""
args = InnerAppDSLImportPayload.model_validate(inner_api_ns.payload or {})
account = _get_active_account(args.creator_email)
if account is None:
return {"message": f"account '{args.creator_email}' not found or inactive"}, 404
account.set_tenant_id(workspace_id)
with Session(db.engine) as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=args.yaml_content,
name=args.name,
description=args.description,
)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
if result.status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@inner_api_ns.route("/enterprise/apps/<string:app_id>/dsl")
class EnterpriseAppDSLExport(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_ns.doc(
"enterprise_app_dsl_export",
responses={
200: "Export successful",
404: "App not found",
},
)
def get(self, app_id: str):
"""Export an app's DSL as YAML."""
include_secret = request.args.get("include_secret", "false").lower() == "true"
app_model = db.session.query(App).filter_by(id=app_id).first()
if not app_model:
return {"message": "app not found"}, 404
data = AppDslService.export_dsl(
app_model=app_model,
include_secret=include_secret,
)
return {"data": data}, 200
def _get_active_account(email: str) -> Account | None:
"""Look up an active account by email.
Workspace membership is already validated by the Go admin-api caller.
"""
account = db.session.query(Account).filter_by(email=email).first()
if account is None or account.status != AccountStatus.ACTIVE:
return None
return account

View File

@ -1,5 +1,4 @@
import json
import logging
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
@ -8,8 +7,6 @@ from typing import TYPE_CHECKING, Any, Union
import httpx
import qdrant_client
logger = logging.getLogger(__name__)
from flask import current_app
from httpx import DigestAuth
from pydantic import BaseModel
@ -291,27 +288,26 @@ class TidbOnQdrantVector(BaseVector):
if not ids:
return
batch_size = 1000
for i in range(0, len(ids), batch_size):
batch = ids[i : i + batch_size]
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=batch),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code != 404:
raise e
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=ids),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
@ -420,16 +416,13 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id)
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id)
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
@ -440,18 +433,11 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
.one_or_none()
)
if idle_tidb_auth_binding:
logger.info(
"Assigning idle cluster %s to tenant %s",
idle_tidb_auth_binding.cluster_id,
dataset.tenant_id,
)
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
tidb_auth_binding = idle_tidb_auth_binding
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id)
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
dify_config.TIDB_API_URL or "",
@ -460,39 +446,21 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
dify_config.TIDB_PRIVATE_KEY or "",
dify_config.TIDB_REGION or "",
)
logger.info(
"New cluster created: cluster_id=%s, qdrant_endpoint=%s",
new_cluster["cluster_id"],
new_cluster.get("qdrant_endpoint"),
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
tenant_id=dataset.tenant_id,
active=True,
status=TidbAuthBindingStatus.ACTIVE,
)
db.session.add(new_tidb_auth_binding)
db.session.commit()
tidb_auth_binding = new_tidb_auth_binding
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id)
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
qdrant_url = (
(tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or ""
)
logger.info(
"Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)",
qdrant_url,
tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None,
dify_config.TIDB_ON_QDRANT_URL,
)
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -507,7 +475,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=qdrant_url,
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=str(config.root_path),
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,

View File

@ -1,4 +1,3 @@
import logging
import time
import uuid
from collections.abc import Sequence
@ -12,50 +11,8 @@ from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
logger = logging.getLogger(__name__)
class TidbService:
@staticmethod
def extract_qdrant_endpoint(cluster_response: dict) -> str | None:
"""Extract the qdrant endpoint URL from a Get Cluster API response.
Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``),
prepends ``qdrant-`` and wraps it as an ``https://`` URL.
"""
endpoints = cluster_response.get("endpoints") or {}
public = endpoints.get("public") or {}
host = public.get("host")
if host:
return f"https://qdrant-{host}"
return None
@staticmethod
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
"""Call Get Cluster API and extract the qdrant endpoint.
Use ``extract_qdrant_endpoint`` instead when you already have
the cluster response to avoid a redundant API call.
"""
try:
logger.info("Fetching qdrant endpoint for cluster %s", cluster_id)
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if not cluster_response:
logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id)
return None
qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response)
if qdrant_url:
logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url)
return qdrant_url
logger.warning(
"No endpoints.public.host found for cluster %s, response keys: %s",
cluster_id,
list(cluster_response.keys()),
)
except Exception:
logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id)
return None
@staticmethod
def create_tidb_serverless_cluster(
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
@ -93,45 +50,26 @@ class TidbService:
"rootPassword": password,
}
logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region)
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
cluster_id = response_data["clusterId"]
logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id)
retry_count = 0
max_retries = 30
while retry_count < max_retries:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response)
logger.info(
"Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s",
cluster_id,
user_prefix,
qdrant_endpoint,
)
return {
"cluster_id": cluster_id,
"cluster_name": display_name,
"account": f"{user_prefix}.root",
"password": password,
"qdrant_endpoint": qdrant_endpoint,
}
logger.info(
"Cluster %s state=%s, retry %d/%d",
cluster_id,
cluster_response["state"],
retry_count + 1,
max_retries,
)
time.sleep(30)
time.sleep(30) # wait 30 seconds before retrying
retry_count += 1
logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries)
else:
logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text)
response.raise_for_status()
@staticmethod
@ -233,20 +171,8 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
if not cluster_info.qdrant_endpoint:
cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint(
item
) or TidbService.fetch_qdrant_endpoint(
api_url, public_key, private_key, item["clusterId"]
)
if cluster_info.qdrant_endpoint:
cluster_info.status = TidbAuthBindingStatus.ACTIVE
else:
logger.warning(
"Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later",
item["clusterId"],
)
db.session.add(cluster_info)
db.session.commit()
else:
@ -304,29 +230,19 @@ class TidbService:
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
logger.info("Batch created %d clusters", len(response_data.get("clusters", [])))
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
cached_password = redis_client.get(cache_key)
if not cached_password:
logger.warning("No cached password for cluster %s, skipping", item["displayName"])
continue
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
logger.info(
"Batch cluster %s: qdrant_endpoint=%s",
item["clusterId"],
qdrant_endpoint,
)
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": cached_password.decode("utf-8"),
"qdrant_endpoint": qdrant_endpoint,
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text)
response.raise_for_status()
return []

View File

@ -1,17 +1,56 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -19,3 +58,152 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@ -1,26 +0,0 @@
"""add qdrant_endpoint to tidb_auth_bindings
Revision ID: 8574b23a38fd
Revises: 6b5f9f8b1a2c
Create Date: 2026-04-14 15:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "8574b23a38fd"
down_revision = "6b5f9f8b1a2c"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True))
def downgrade():
with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op:
batch_op.drop_column("qdrant_endpoint")

View File

@ -1250,7 +1250,6 @@ class TidbAuthBinding(TypeBase):
)
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -113,7 +113,6 @@ class DataSourceType(StrEnum):
WEBSITE_CRAWL = "website_crawl"
LOCAL_FILE = "local_file"
ONLINE_DOCUMENT = "online_document"
ONLINE_DRIVE = "online_drive"
class ProcessRuleMode(StrEnum):

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.13.3"
version = "1.13.2"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -57,7 +57,6 @@ def create_clusters(batch_size):
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
qdrant_endpoint=new_cluster.get("qdrant_endpoint"),
active=False,
status=TidbAuthBindingStatus.CREATING,
)

View File

@ -4,7 +4,7 @@ import logging
import threading
import uuid
from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -18,13 +18,12 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -89,7 +88,7 @@ class AppGenerateService:
def generate(
cls,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@ -107,7 +106,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@ -117,150 +116,139 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
)
match effective_mode:
case AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
request_id=request_id,
)
case AppMode.AGENT_CHAT:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
request_id,
)
case AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id,
)
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
request_id=request_id,
)
case AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
if streaming:
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id=request_id,
)
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
advanced_generator.generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id=request_id,
)
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
advanced_generator.generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
pause_state_config=pause_config,
)
),
request_id=request_id,
)
case AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
pause_state_config=pause_config,
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id,
)
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
pause_state_config=pause_config,
),
),
request_id,
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
except Exception:
quota_charge.refund()
rate_limit.exit(request_id)
@ -292,83 +280,53 @@ class AppGenerateService:
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
match app_model.mode:
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
case AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_single_loop(
cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
):
match app_model.mode:
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
case AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_more_like_this(
cls,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping | Generator:
) -> Union[Mapping, Generator]:
"""
Generate more like this
:param app_model: app model

View File

@ -22,7 +22,6 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -89,10 +88,7 @@ class AsyncWorkflowService:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
# commit read only session before starting the billig rpc call
session.commit()
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@ -135,10 +131,9 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
# 7. Check and consume quota
try:
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -158,18 +153,13 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
@ -305,21 +295,13 @@ class AsyncWorkflowService:
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(
workflow_service: WorkflowService,
app_model: App,
workflow_id: str | None = None,
session: Session | None = None,
) -> Workflow:
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
"""
Get workflow for the app
Args:
app_model: App model instance
workflow_id: Optional specific workflow ID
session: Reuse this SQLAlchemy session for the lookup when provided,
so the caller's explicit session bears the connection cost
instead of Flask's request-scoped ``db.session``.
Returns:
Workflow instance
@ -329,12 +311,12 @@ class AsyncWorkflowService:
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model, session=session)
workflow = workflow_service.get_published_workflow(app_model)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

View File

@ -2,11 +2,12 @@ import json
import logging
import os
from collections.abc import Sequence
from typing import Literal, NotRequired, TypedDict
from typing import Literal
import httpx
from pydantic import TypeAdapter
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
from enums.cloud_plan import CloudPlan
@ -25,147 +26,6 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
reservation_id: str
available: int
reserved: int
class QuotaCommitResult(TypedDict):
available: int
reserved: int
refunded: int
class QuotaReleaseResult(TypedDict):
available: int
reserved: int
released: int
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _TenantFeatureQuota(TypedDict):
usage: int
limit: int
reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
"""Response of /quota/info.
NOTE (hj24):
- Same convention as BillingInfo: billing may return int fields as str,
always keep non-strict mode to auto-coerce.
"""
trigger_event: _TenantFeatureQuota
api_rate_limit: _TenantFeatureQuota
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict):
size: int
limit: int
class _VectorSpaceQuota(TypedDict):
size: float
limit: int
class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
# 2. Keep it for compatibility for now, can be deprecated in future versions.
size: NotRequired[int]
# NOTE END
limit: int
class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool
class BillingInfo(TypedDict):
"""Response of /subscription/info.
NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
1. validate_python in non-strict mode will coerce it to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""
enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]
_billing_info_adapter = TypeAdapter(BillingInfo)
class KnowledgeRateLimitDict(TypedDict):
limit: int
subscription_plan: str
class TenantFeaturePlanUsageDict(TypedDict):
result: str
history_id: str
class LangContentDict(TypedDict):
lang: str
title: str
subtitle: str
body: str
title_pic_url: str
class NotificationDict(TypedDict):
notification_id: str
contents: dict[str, LangContentDict]
frequency: Literal["once", "every_page_load"]
class AccountNotificationDict(TypedDict, total=False):
should_show: bool
notification: NotificationDict
shouldShow: bool
notifications: list[dict]
class UpsertNotificationDict(TypedDict):
notification_id: str
class BatchAddNotificationAccountsDict(TypedDict):
count: int
class DismissNotificationDict(TypedDict):
success: bool
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@ -178,73 +38,21 @@ class BillingService:
_PLAN_CACHE_TTL = 600
@classmethod
def get_info(cls, tenant_id: str) -> BillingInfo:
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return _billing_info_adapter.validate_python(billing_info)
return billing_info
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
params = {"tenant_id": tenant_id}
return _tenant_feature_quota_info_adapter.validate_python(
cls._send_request("GET", "/quota/info", params=params)
)
@classmethod
def quota_reserve(
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
"""Reserve quota before task execution."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"request_id": request_id,
"amount": amount,
}
if meta:
payload["meta"] = meta
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
"""Commit a reservation with actual consumption."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
"actual_amount": actual_amount,
}
if meta:
payload["meta"] = meta
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
"""Release a reservation (cancel, return frozen quota)."""
return _quota_release_adapter.validate_python(
cls._send_request(
"POST",
"/quota/release",
json={
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
},
)
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

View File

@ -1,7 +1,7 @@
import logging
from sqlalchemy import select, update
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.errors.error import QuotaExceededError
@ -29,15 +29,14 @@ class CreditPoolService:
@classmethod
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
"""get tenant credit pool"""
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return session.scalar(
select(TenantCreditPool)
.where(
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.pool_type == pool_type,
)
.limit(1)
return (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
pool_type=pool_type,
)
.first()
)
@classmethod
def check_credits_available(

View File

@ -281,7 +281,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
@ -312,10 +312,7 @@ class FeatureService:
features.apps.limit = billing_info["apps"]["limit"]
if "vector_space" in billing_info:
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
# but LimitationModel.size is int; truncate here for compatibility
features.vector_space.size = int(billing_info["vector_space"]["size"])
# NOTE END
features.vector_space.size = billing_info["vector_space"]["size"]
features.vector_space.limit = billing_info["vector_space"]["limit"]
if "documents_upload_quota" in billing_info:
@ -336,11 +333,7 @@ class FeatureService:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
# NOTE (hj24):
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
# NOTE END
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]

View File

@ -1,233 +0,0 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota reservation (Reserve phase).
Lifecycle:
charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
try:
do_work()
charge.commit() # Confirm consumption
except:
charge.refund() # Release frozen quota
If neither commit() nor refund() is called, the billing system's
cleanup CronJob will auto-release the reservation within ~75 seconds.
"""
success: bool
charge_id: str | None # reservation_id
_quota_type: QuotaType
_tenant_id: str | None = None
_feature_key: str | None = None
_amount: int = 0
_committed: bool = field(default=False, repr=False)
def commit(self, actual_amount: int | None = None) -> None:
"""
Confirm the consumption with actual amount.
Args:
actual_amount: Actual amount consumed. Defaults to the reserved amount.
If less than reserved, the difference is refunded automatically.
"""
if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
return
try:
from services.billing_service import BillingService
amount = actual_amount if actual_amount is not None else self._amount
BillingService.quota_commit(
tenant_id=self._tenant_id,
feature_key=self._feature_key,
reservation_id=self.charge_id,
actual_amount=amount,
)
self._committed = True
logger.debug(
"Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
self._quota_type,
self._tenant_id,
self.charge_id,
amount,
)
except Exception:
logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
def refund(self) -> None:
"""
Release the reserved quota (cancel the charge).
Safe to call even if:
- charge failed or was disabled (charge_id is None)
- already committed (Release after Commit is a no-op)
- already refunded (idempotent)
This method guarantees no exceptions will be raised.
"""
if not self.charge_id or not self._tenant_id or not self._feature_key:
return
QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
from enums.quota_type import QuotaType
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
class QuotaService:
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
@staticmethod
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve + immediate Commit (one-shot mode).
The returned QuotaCharge supports .refund() which calls Release.
For two-phase usage (e.g. streaming), use reserve() directly.
"""
charge = QuotaService.reserve(quota_type, tenant_id, amount)
if charge.success and charge.charge_id:
charge.commit()
return charge
@staticmethod
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Reserve quota before task execution (Reserve phase only).
The caller MUST call charge.commit() after the task succeeds,
or charge.refund() if the task fails.
Raises:
QuotaExceededError: When quota is insufficient
"""
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to reserve must be greater than 0")
request_id = str(uuid.uuid4())
feature_key = quota_type.billing_key
try:
reserve_resp = BillingService.quota_reserve(
tenant_id=tenant_id,
feature_key=feature_key,
request_id=request_id,
amount=amount,
)
reservation_id = reserve_resp.get("reservation_id")
if not reservation_id:
logger.warning(
"Reserve returned no reservation_id for %s, feature %s, response: %s",
tenant_id,
quota_type.value,
reserve_resp,
)
raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
logger.debug(
"Reserved %d %s quota for tenant %s, reservation_id: %s",
amount,
quota_type.value,
tenant_id,
reservation_id,
)
return QuotaCharge(
success=True,
charge_id=reservation_id,
_quota_type=quota_type,
_tenant_id=tenant_id,
_feature_key=feature_key,
_amount=amount,
)
except QuotaExceededError:
raise
except ValueError:
raise
except Exception:
logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
return unlimited()
@staticmethod
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = QuotaService.get_remaining(quota_type, tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
return True
@staticmethod
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
"""Release a reservation. Guarantees no exceptions."""
try:
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not reservation_id:
return
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
BillingService.quota_release(
tenant_id=tenant_id,
feature_key=feature_key,
reservation_id=reservation_id,
)
except Exception:
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
@staticmethod
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
from services.billing_service import BillingService
try:
usage_info = BillingService.get_quota_info(tenant_id)
if isinstance(usage_info, dict):
feature_info = usage_info.get(quota_type.billing_key, {})
if isinstance(feature_info, dict):
limit = feature_info.get("limit", 0)
usage = feature_info.get("usage", 0)
if limit == -1:
return -1
return max(0, limit - usage)
return 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
return -1

View File

@ -37,7 +37,6 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@ -759,47 +758,45 @@ class WebhookService:
Exception: If workflow execution fails
"""
try:
workflow_inputs = cls.build_workflow_inputs(webhook_data)
with Session(db.engine) as session:
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id,
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
# Create trigger data
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)
raise
try:
# NOTE: don not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`
# trigger_workflow_async need to handle multipe session commits internally
with Session(db.engine, expire_on_commit=False) as session:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
# consume quota before triggering workflow execution
try:
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
raise
# Trigger workflow execution asynchronously
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -132,38 +132,31 @@ class WorkflowService:
if workflow_id:
return self.get_published_workflow_by_id(app_model, workflow_id)
# fetch draft workflow by app_model
workflow = db.session.scalar(
select(Workflow)
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.limit(1)
.first()
)
# return draft workflow
return workflow
def get_published_workflow_by_id(
self, app_model: App, workflow_id: str, session: Session | None = None
) -> Workflow | None:
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
"""
fetch published workflow by workflow_id
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id,
)
.limit(1)
.first()
)
if not workflow:
return None
@ -174,27 +167,23 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None:
def get_published_workflow(self, app_model: App) -> Workflow | None:
"""
Get published workflow
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
if not app_model.workflow_id:
return None
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id,
)
.limit(1)
.first()
)
return workflow

View File

@ -156,12 +156,7 @@ def _execute_workflow_common(
state_owner_user_id=workflow.created_by,
)
# NOTE (hj24)
# Release the transaction before the blocking generate() call,
# otherwise the connection stays "idle in transaction" for hours.
session.commit()
# NOTE END
# Execute the workflow with the trigger type
generator.generate(
app_model=app_model,
workflow=workflow,

View File

@ -28,7 +28,7 @@ from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from dify_graph.enums import WorkflowExecutionStatus
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from models.enums import (
AppTriggerType,
CreatorUserRole,
@ -42,7 +42,6 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@ -259,58 +258,59 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
# Ensure expire_on_commit is set to False to remain workflows available
with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
for plugin_trigger in subscribers:
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
for plugin_trigger in subscribers:
# Get workflow from mapping
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id)
return dispatched_count
# Find the trigger node in the workflow
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
with session_factory.create_session() as session:
# invoke trigger
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
# consume quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
)
return 0
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
try:
invoke_response = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
@ -387,7 +387,6 @@ def dispatch_triggered_workflow(
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
quota_charge.commit()
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",
@ -402,7 +401,7 @@ def dispatch_triggered_workflow(
plugin_trigger.app_id,
)
return dispatched_count
return dispatched_count
def dispatch_triggered_workflows(

View File

@ -8,11 +8,10 @@ from core.workflow.nodes.trigger_schedule.exc import (
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@ -33,7 +32,6 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
# Ensure expire_on_commit is set to False to remain schedule/tenant_owner available
with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
@ -43,16 +41,16 @@ def run_schedule_trigger(schedule_id: str) -> None:
if not tenant_owner:
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
quota_charge = unlimited()
try:
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
try:
with session_factory.create_session() as session:
try:
# Production dispatch: Trigger the workflow normally
response = AsyncWorkflowService.trigger_workflow_async(
session=session,
user=tenant_owner,
@ -63,10 +61,9 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e

View File

@ -163,9 +163,11 @@ class DifyTestContainers:
wait_for_logs(self.redis, "Ready to accept connections", timeout=30)
logger.info("Redis container is ready and accepting connections")
# Start Dify Sandbox container for code execution environment.
# Start Dify Sandbox container for code execution environment
# Dify Sandbox provides a secure environment for executing user code
# Use pinned version 0.2.12 to match production docker-compose configuration
logger.info("Initializing Dify Sandbox container...")
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.14").with_network(self.network)
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network)
self.dify_sandbox.with_exposed_ports(8194)
self.dify_sandbox.env = {
"API_KEY": "test_api_key",
@ -185,7 +187,7 @@ class DifyTestContainers:
# Start Dify Plugin Daemon container for plugin management
# Dify Plugin Daemon provides plugin lifecycle management and execution
logger.info("Initializing Dify Plugin Daemon container...")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.3-local").with_network(
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.4-local").with_network(
self.network
)
self.dify_plugin_daemon.with_exposed_ports(5002)

View File

@ -36,19 +36,12 @@ class TestAppGenerateService:
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.quota_reserve.return_value = {
"reservation_id": "test-reservation-id",
"available": 100,
"reserved": 1,
}
mock_billing_service.quota_commit.return_value = {
"available": 99,
"reserved": 0,
"refunded": 0,
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
"result": "success",
"history_id": "test_history_id",
}
# Setup default mock returns for workflow service
@ -108,8 +101,6 @@ class TestAppGenerateService:
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_quota_dify_config.BILLING_ENABLED = False
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@ -127,7 +118,6 @@ class TestAppGenerateService:
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"quota_dify_config": mock_quota_dify_config,
"global_dify_config": mock_global_dify_config,
}
@ -475,7 +465,6 @@ class TestAppGenerateService:
# Set BILLING_ENABLED to True for this test
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
# Setup test arguments
@ -489,10 +478,8 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
# Verify billing two-phase quota (reserve + commit)
billing = mock_external_service_dependencies["billing_service"]
billing.quota_reserve.assert_called_once()
billing.quota_commit.assert_called_once()
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -1,517 +0,0 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from enums.quota_type import QuotaType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow
from services.errors.app import QuotaExceededError
from services.trigger.webhook_service import WebhookService
class WebhookServiceRelationshipFactory:
@staticmethod
def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]:
account = Account(
name=f"Account {uuid4()}",
email=f"webhook-{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal")
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account, tenant
@staticmethod
def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App:
app = App(
tenant_id=tenant.id,
name=f"Webhook App {uuid4()}",
description="",
mode="workflow",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=account.id,
updated_by=account.id,
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
@staticmethod
def create_workflow(
db_session_with_containers: Session,
*,
app: App,
account: Account,
node_ids: list[str],
version: str,
) -> Workflow:
graph = {
"nodes": [
{
"id": node_id,
"data": {
"type": TRIGGER_WEBHOOK_NODE_TYPE,
"title": f"Webhook {node_id}",
"method": "post",
"content_type": "application/json",
"headers": [],
"params": [],
"body": [],
"status_code": 200,
"response_body": '{"status": "ok"}',
"timeout": 30,
},
}
for node_id in node_ids
],
"edges": [],
}
workflow = Workflow(
tenant_id=app.tenant_id,
app_id=app.id,
type="workflow",
graph=json.dumps(graph),
features=json.dumps({}),
created_by=account.id,
updated_by=account.id,
environment_variables=[],
conversation_variables=[],
version=version,
)
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
return workflow
@staticmethod
def create_webhook_trigger(
db_session_with_containers: Session,
*,
app: App,
account: Account,
node_id: str,
webhook_id: str | None = None,
) -> WorkflowWebhookTrigger:
webhook_trigger = WorkflowWebhookTrigger(
app_id=app.id,
node_id=node_id,
tenant_id=app.tenant_id,
webhook_id=webhook_id or uuid4().hex[:24],
created_by=account.id,
)
db_session_with_containers.add(webhook_trigger)
db_session_with_containers.commit()
return webhook_trigger
@staticmethod
def create_app_trigger(
db_session_with_containers: Session,
*,
app: App,
node_id: str,
status: AppTriggerStatus,
) -> AppTrigger:
app_trigger = AppTrigger(
tenant_id=app.tenant_id,
app_id=app.id,
node_id=node_id,
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
provider_name="webhook",
title=f"Webhook {node_id}",
status=status,
)
db_session_with_containers.add(app_trigger)
db_session_with_containers.commit()
return app_trigger
class TestWebhookServiceLookupWithContainers:
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
with pytest.raises(ValueError, match="App trigger not found"):
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
factory.create_app_trigger(
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED
)
with pytest.raises(ValueError, match="rate limited"):
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
factory.create_app_trigger(
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED
)
with pytest.raises(ValueError, match="disabled"):
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
factory.create_app_trigger(
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED
)
with pytest.raises(ValueError, match="Workflow not found"):
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
factory.create_workflow(
db_session_with_containers,
app=app,
account=account,
node_ids=["published-node"],
version="2026-04-14.001",
)
draft_workflow = factory.create_workflow(
db_session_with_containers,
app=app,
account=account,
node_ids=["debug-node"],
version=Workflow.VERSION_DRAFT,
)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="debug-node"
)
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
webhook_trigger.webhook_id,
is_debug=True,
)
assert got_trigger.id == webhook_trigger.id
assert got_workflow.id == draft_workflow.id
assert got_node_config["id"] == "debug-node"
class TestWebhookServiceTriggerExecutionWithContainers:
def test_trigger_workflow_execution_triggers_async_workflow_successfully(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
end_user = SimpleNamespace(id=str(uuid4()))
webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
quota_charge = MagicMock()
with (
patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
return_value=end_user,
),
patch(
"services.trigger.webhook_service.QuotaService.reserve",
return_value=quota_charge,
) as mock_reserve,
patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
):
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
mock_reserve.assert_called_once()
reserve_args = mock_reserve.call_args.args
assert reserve_args[0] == QuotaType.TRIGGER
assert reserve_args[1] == webhook_trigger.tenant_id
quota_charge.commit.assert_called_once()
mock_trigger.assert_called_once()
trigger_args = mock_trigger.call_args.args
assert trigger_args[1] is end_user
assert trigger_args[2].workflow_id == workflow.id
assert trigger_args[2].root_node_id == webhook_trigger.node_id
def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
with (
patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
return_value=SimpleNamespace(id=str(uuid4())),
),
patch(
"services.trigger.webhook_service.QuotaService.reserve",
side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
),
patch(
"services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited"
) as mock_mark_rate_limited,
):
with pytest.raises(QuotaExceededError):
WebhookService.trigger_workflow_execution(
webhook_trigger,
{"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
workflow,
)
mock_mark_rate_limited.assert_called_once_with(tenant.id)
def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
)
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
with (
patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
side_effect=RuntimeError("boom"),
),
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
):
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(
webhook_trigger,
{"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"},
workflow,
)
mock_logger_exception.assert_called_once()
class TestWebhookServiceRelationshipSyncWithContainers:
def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)]
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT
)
with pytest.raises(ValueError, match="maximum webhook node limit"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_raises_when_lock_not_acquired(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT
)
lock = MagicMock()
lock.acquire.return_value = False
with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock):
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
stale_trigger = factory.create_webhook_trigger(
db_session_with_containers,
app=app,
account=account,
node_id="node-stale",
webhook_id="stale-webhook-id-000001",
)
stale_trigger_id = stale_trigger.id
workflow = factory.create_workflow(
db_session_with_containers,
app=app,
account=account,
node_ids=["node-new"],
version=Workflow.VERSION_DRAFT,
)
with patch(
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001"
):
WebhookService.sync_webhook_relationships(app, workflow)
db_session_with_containers.expire_all()
records = db_session_with_containers.scalars(
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id)
).all()
assert [record.node_id for record in records] == ["node-new"]
assert records[0].webhook_id == "new-webhook-id-000001"
assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None
def test_sync_webhook_relationships_sets_redis_cache_for_new_record(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers,
app=app,
account=account,
node_ids=["node-cache"],
version=Workflow.VERSION_DRAFT,
)
cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache"
with patch(
"services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001"
):
WebhookService.sync_webhook_relationships(app, workflow)
cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key)
assert cached_payload is not None
assert cached_payload["node_id"] == "node-cache"
assert cached_payload["webhook_id"] == "cache-webhook-id-00001"
def test_sync_webhook_relationships_logs_when_lock_release_fails(
self, db_session_with_containers: Session, flask_app_with_containers
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
app = factory.create_app(db_session_with_containers, tenant, account)
workflow = factory.create_workflow(
db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT
)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
with (
patch("services.trigger.webhook_service.redis_client.lock", return_value=lock),
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
):
WebhookService.sync_webhook_relationships(app, workflow)
mock_logger_exception.assert_called_once()
def _read_cache(cache_key: str) -> dict[str, str] | None:
from extensions.ext_redis import redis_client
cached = redis_client.get(cache_key)
if not cached:
return None
if isinstance(cached, bytes):
cached = cached.decode("utf-8")
return json.loads(cached)
WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache)

View File

@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log(
)
# Mock quota to avoid rate limiting
from services import quota_service
from enums import quota_type
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)

View File

@ -1,245 +0,0 @@
"""Unit tests for inner_api app DSL import/export endpoints.
Tests Pydantic model validation, endpoint handler logic, and the
_get_active_account helper. Auth/setup decorators are tested separately
in test_auth_wraps.py; handler tests use inspect.unwrap() to bypass them.
"""
import inspect
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.inner_api.app.dsl import (
EnterpriseAppDSLExport,
EnterpriseAppDSLImport,
InnerAppDSLImportPayload,
_get_active_account,
)
from services.app_dsl_service import ImportStatus
class TestInnerAppDSLImportPayload:
"""Test InnerAppDSLImportPayload Pydantic model validation."""
def test_valid_payload_all_fields(self):
data = {
"yaml_content": "version: 0.6.0\nkind: app\n",
"creator_email": "user@example.com",
"name": "My App",
"description": "A test app",
}
payload = InnerAppDSLImportPayload.model_validate(data)
assert payload.yaml_content == data["yaml_content"]
assert payload.creator_email == "user@example.com"
assert payload.name == "My App"
assert payload.description == "A test app"
def test_valid_payload_optional_fields_omitted(self):
data = {
"yaml_content": "version: 0.6.0\n",
"creator_email": "user@example.com",
}
payload = InnerAppDSLImportPayload.model_validate(data)
assert payload.name is None
assert payload.description is None
def test_missing_yaml_content_fails(self):
with pytest.raises(ValidationError) as exc_info:
InnerAppDSLImportPayload.model_validate({"creator_email": "a@b.com"})
assert "yaml_content" in str(exc_info.value)
def test_missing_creator_email_fails(self):
with pytest.raises(ValidationError) as exc_info:
InnerAppDSLImportPayload.model_validate({"yaml_content": "test"})
assert "creator_email" in str(exc_info.value)
class TestGetActiveAccount:
"""Test the _get_active_account helper function."""
@patch("controllers.inner_api.app.dsl.db")
def test_returns_active_account(self, mock_db):
mock_account = MagicMock()
mock_account.status = "active"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
result = _get_active_account("user@example.com")
assert result is mock_account
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
@patch("controllers.inner_api.app.dsl.db")
def test_returns_none_for_inactive_account(self, mock_db):
mock_account = MagicMock()
mock_account.status = "banned"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
result = _get_active_account("banned@example.com")
assert result is None
@patch("controllers.inner_api.app.dsl.db")
def test_returns_none_for_nonexistent_email(self, mock_db):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
result = _get_active_account("missing@example.com")
assert result is None
class TestEnterpriseAppDSLImport:
"""Test EnterpriseAppDSLImport endpoint handler logic.
Uses inspect.unwrap() to bypass auth/setup decorators.
"""
@pytest.fixture
def api_instance(self):
return EnterpriseAppDSLImport()
@pytest.fixture
def _mock_import_deps(self):
"""Patch db, Session, and AppDslService for import handler tests."""
with (
patch("controllers.inner_api.app.dsl.db"),
patch("controllers.inner_api.app.dsl.Session") as mock_session,
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
):
mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock())
mock_session.return_value.__exit__ = MagicMock(return_value=False)
self._mock_dsl = MagicMock()
mock_dsl_cls.return_value = self._mock_dsl
yield
def _make_import_result(self, status: ImportStatus, **kwargs) -> "Import":
from services.app_dsl_service import Import
result = Import(
id="import-id",
status=status,
app_id=kwargs.get("app_id", "app-123"),
app_mode=kwargs.get("app_mode", "workflow"),
)
return result
@pytest.mark.usefixtures("_mock_import_deps")
@patch("controllers.inner_api.app.dsl._get_active_account")
def test_import_success_returns_200(self, mock_get_account, api_instance, app: Flask):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.COMPLETED)
unwrapped = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
mock_ns.payload = {
"yaml_content": "version: 0.6.0\n",
"creator_email": "user@example.com",
}
result = unwrapped(api_instance, workspace_id="ws-123")
body, status_code = result
assert status_code == 200
assert body["status"] == "completed"
mock_account.set_tenant_id.assert_called_once_with("ws-123")
@pytest.mark.usefixtures("_mock_import_deps")
@patch("controllers.inner_api.app.dsl._get_active_account")
def test_import_pending_returns_202(self, mock_get_account, api_instance, app: Flask):
mock_get_account.return_value = MagicMock()
self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.PENDING)
unwrapped = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"}
body, status_code = unwrapped(api_instance, workspace_id="ws-123")
assert status_code == 202
assert body["status"] == "pending"
@pytest.mark.usefixtures("_mock_import_deps")
@patch("controllers.inner_api.app.dsl._get_active_account")
def test_import_failed_returns_400(self, mock_get_account, api_instance, app: Flask):
mock_get_account.return_value = MagicMock()
self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.FAILED)
unwrapped = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"}
body, status_code = unwrapped(api_instance, workspace_id="ws-123")
assert status_code == 400
assert body["status"] == "failed"
@patch("controllers.inner_api.app.dsl._get_active_account")
def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask):
mock_get_account.return_value = None
unwrapped = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns:
mock_ns.payload = {"yaml_content": "test", "creator_email": "missing@e.com"}
result = unwrapped(api_instance, workspace_id="ws-123")
body, status_code = result
assert status_code == 404
assert "missing@e.com" in body["message"]
class TestEnterpriseAppDSLExport:
"""Test EnterpriseAppDSLExport endpoint handler logic.
Uses inspect.unwrap() to bypass auth/setup decorators.
"""
@pytest.fixture
def api_instance(self):
return EnterpriseAppDSLExport()
@patch("controllers.inner_api.app.dsl.AppDslService")
@patch("controllers.inner_api.app.dsl.db")
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
mock_app = MagicMock()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
unwrapped = inspect.unwrap(api_instance.get)
with app.test_request_context("?include_secret=false"):
result = unwrapped(api_instance, app_id="app-123")
body, status_code = result
assert status_code == 200
assert body["data"] == "version: 0.6.0\nkind: app\n"
mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=False)
@patch("controllers.inner_api.app.dsl.AppDslService")
@patch("controllers.inner_api.app.dsl.db")
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
mock_app = MagicMock()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
mock_dsl_cls.export_dsl.return_value = "yaml-data"
unwrapped = inspect.unwrap(api_instance.get)
with app.test_request_context("?include_secret=true"):
result = unwrapped(api_instance, app_id="app-123")
body, status_code = result
assert status_code == 200
mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=True)
@patch("controllers.inner_api.app.dsl.db")
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
unwrapped = inspect.unwrap(api_instance.get)
with app.test_request_context("?include_secret=false"):
result = unwrapped(api_instance, app_id="nonexistent")
body, status_code = result
assert status_code == 404
assert "app not found" in body["message"]

View File

@ -8,7 +8,6 @@ import core.app.apps.pipeline.pipeline_generator as module
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceProviderType
from models.enums import DataSourceType
class FakeRagPipelineGenerateEntity(SimpleNamespace):
@ -559,24 +558,6 @@ def test_build_document_sets_metadata_for_builtin_fields(generator, mocker):
assert document.doc_metadata
def test_build_document_supports_online_drive_datasource_type(generator):
document = generator._build_document(
tenant_id="tenant",
dataset_id="ds",
built_in_field_enabled=True,
datasource_type=DatasourceProviderType.ONLINE_DRIVE,
datasource_info={"id": "file-1", "bucket": "bucket-1", "name": "drive.pdf", "type": "file"},
created_from="rag-pipeline",
position=1,
account=_build_user(),
batch="batch",
document_form="text",
)
assert DataSourceType(document.data_source_type) == DataSourceType.ONLINE_DRIVE
assert document.name == "drive.pdf"
def test_build_document_invalid_datasource_type(generator):
with pytest.raises(ValueError):
generator._build_document(

View File

@ -115,12 +115,14 @@ class TestTidbOnQdrantVectorDeleteByIds:
assert exc_info.value.status_code == 500
def test_delete_by_ids_with_exactly_1000(self, vector_instance):
"""Test deletion with exactly 1000 IDs triggers a single batch."""
def test_delete_by_ids_with_large_batch(self, vector_instance):
"""Test deletion with a large batch of IDs."""
# Create 1000 IDs
ids = [f"doc_{i}" for i in range(1000)]
vector_instance.delete_by_ids(ids)
# Verify single delete call with all IDs
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
@ -128,28 +130,11 @@ class TestTidbOnQdrantVectorDeleteByIds:
filter_obj = filter_selector.filter
field_condition = filter_obj.must[0]
# Verify all 1000 IDs are in the batch
assert len(field_condition.match.any) == 1000
assert "doc_0" in field_condition.match.any
assert "doc_999" in field_condition.match.any
def test_delete_by_ids_splits_into_batches(self, vector_instance):
"""Test deletion with >1000 IDs triggers multiple batched calls."""
ids = [f"doc_{i}" for i in range(2500)]
vector_instance.delete_by_ids(ids)
assert vector_instance._client.delete.call_count == 3
batches = []
for call in vector_instance._client.delete.call_args_list:
filter_selector = call[1]["points_selector"]
field_condition = filter_selector.filter.must[0]
batches.append(field_condition.match.any)
assert len(batches[0]) == 1000
assert len(batches[1]) == 1000
assert len(batches[2]) == 500
def test_delete_by_ids_filter_structure(self, vector_instance):
"""Test that the filter structure is correctly constructed."""
ids = ["doc1", "doc2"]
@ -173,57 +158,3 @@ class TestTidbOnQdrantVectorDeleteByIds:
# Verify MatchAny structure
assert isinstance(field_condition.match, rest.MatchAny)
assert field_condition.match.any == ids
class TestInitVectorEndpointSelection:
"""Test that init_vector selects the correct qdrant endpoint.
We avoid importing the full module (which triggers Flask app context)
by testing the endpoint selection logic directly on TidbOnQdrantConfig.
"""
def test_uses_binding_endpoint_when_present(self):
binding_endpoint = "https://qdrant-custom.tidb.com"
global_url = "https://qdrant-global.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-custom.tidb.com"
config = TidbOnQdrantConfig(endpoint=qdrant_url)
assert config.endpoint == "https://qdrant-custom.tidb.com"
def test_falls_back_to_global_when_binding_endpoint_is_none(self):
binding_endpoint = None
global_url = "https://qdrant-global.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-global.tidb.com"
config = TidbOnQdrantConfig(endpoint=qdrant_url)
assert config.endpoint == "https://qdrant-global.tidb.com"
def test_falls_back_to_empty_when_both_none(self):
binding_endpoint = None
global_url = None
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == ""
config = TidbOnQdrantConfig(endpoint=qdrant_url)
assert config.endpoint == ""
def test_binding_endpoint_takes_precedence_over_global(self):
binding_endpoint = "https://qdrant-ap-southeast.tidb.com"
global_url = "https://qdrant-us-east.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-ap-southeast.tidb.com"
def test_empty_string_binding_endpoint_falls_back_to_global(self):
binding_endpoint = ""
global_url = "https://qdrant-global.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-global.tidb.com"

View File

@ -1,308 +0,0 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from models.enums import TidbAuthBindingStatus
class TestExtractQdrantEndpoint:
"""Unit tests for TidbService.extract_qdrant_endpoint."""
def test_returns_endpoint_when_host_present(self):
response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}}
result = TidbService.extract_qdrant_endpoint(response)
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
def test_returns_none_when_host_missing(self):
response = {"endpoints": {"public": {}}}
assert TidbService.extract_qdrant_endpoint(response) is None
def test_returns_none_when_public_missing(self):
response = {"endpoints": {}}
assert TidbService.extract_qdrant_endpoint(response) is None
def test_returns_none_when_endpoints_missing(self):
assert TidbService.extract_qdrant_endpoint({}) is None
class TestFetchQdrantEndpoint:
"""Unit tests for TidbService.fetch_qdrant_endpoint."""
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_endpoint_when_host_present(self, mock_get_cluster):
mock_get_cluster.return_value = {
"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}
}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster):
mock_get_cluster.return_value = None
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_host_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {"endpoints": {"public": {}}}
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_endpoints_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {}
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_on_exception(self, mock_get_cluster):
mock_get_cluster.side_effect = RuntimeError("network error")
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
class TestCreateTidbServerlessClusterQdrantEndpoint:
"""Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result."""
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {
"state": "ACTIVE",
"userPrefix": "pfx",
"endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}},
}
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is not None
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is not None
assert result["qdrant_endpoint"] is None
class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
"""Verify that batch_create includes qdrant_endpoint per cluster."""
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.redis_client")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
mock_config.TIDB_SPEND_LIMIT = 10
cluster_name = "abc123"
mock_http.post.return_value = MagicMock(
status_code=200,
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]},
)
mock_redis.setex = MagicMock()
mock_redis.get.return_value = b"password123"
result = TidbService.batch_create_tidb_serverless_cluster(
batch_size=1,
project_id="proj",
api_url="url",
iam_url="iam",
public_key="pub",
private_key="priv",
region="us-east-1",
)
assert len(result) == 1
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
class TestCreateTidbServerlessClusterRetry:
"""Cover retry/logging paths in create_tidb_serverless_cluster."""
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.side_effect = [
{"state": "CREATING", "userPrefix": ""},
{"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}},
]
with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.time.sleep"):
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is not None
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com"
assert mock_get_cluster.call_count == 2
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""}
with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.time.sleep"):
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is None
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_raises_on_post_failure(self, mock_config, mock_http):
mock_config.TIDB_SPEND_LIMIT = 10
mock_response = MagicMock(status_code=400, text="Bad Request")
mock_response.raise_for_status.side_effect = Exception("HTTP 400")
mock_http.post.return_value = mock_response
with pytest.raises(Exception, match="HTTP 400"):
TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
class TestBatchCreateEdgeCases:
"""Cover logging/edge-case branches in batch_create."""
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.redis_client")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(
status_code=200,
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]},
)
mock_redis.setex = MagicMock()
mock_redis.get.return_value = None
result = TidbService.batch_create_tidb_serverless_cluster(
batch_size=1,
project_id="proj",
api_url="url",
iam_url="iam",
public_key="pub",
private_key="priv",
region="us-east-1",
)
assert len(result) == 0
mock_fetch_ep.assert_not_called()
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.redis_client")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.dify_config")
def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis):
mock_config.TIDB_SPEND_LIMIT = 10
mock_response = MagicMock(status_code=500, text="Server Error")
mock_response.raise_for_status.side_effect = Exception("HTTP 500")
mock_http.post.return_value = mock_response
mock_redis.setex = MagicMock()
with pytest.raises(Exception, match="HTTP 500"):
TidbService.batch_create_tidb_serverless_cluster(
batch_size=1,
project_id="proj",
api_url="url",
iam_url="iam",
public_key="pub",
private_key="priv",
region="us-east-1",
)
class TestBatchUpdateTidbServerlessClusterStatus:
"""Verify that status updates only expose clusters after qdrant endpoint is ready."""
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.db")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {
"clusters": [
{
"clusterId": "c-1",
"state": "ACTIVE",
"userPrefix": "pfx",
"endpoints": {"public": {"host": "gw.tidbcloud.com"}},
}
]
},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
assert binding.status == TidbAuthBindingStatus.ACTIVE
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.db")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {
"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]
},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
assert binding.status == TidbAuthBindingStatus.ACTIVE
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.db")
@patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_service.httpx")
def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {
"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]
},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint is None
assert binding.status == TidbAuthBindingStatus.CREATING
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()

View File

@ -1,349 +0,0 @@
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
from unittest.mock import patch
import pytest
from enums.quota_type import QuotaType
from services.quota_service import QuotaCharge, QuotaService, unlimited
class TestQuotaType:
def test_billing_key_trigger(self):
assert QuotaType.TRIGGER.billing_key == "trigger_event"
def test_billing_key_workflow(self):
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
def test_billing_key_unlimited_raises(self):
with pytest.raises(ValueError, match="Invalid quota type"):
_ = QuotaType.UNLIMITED.billing_key
class TestQuotaService:
def test_reserve_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService"),
):
mock_cfg.BILLING_ENABLED = False
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_reserve_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
def test_reserve_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
assert charge.success is True
assert charge.charge_id == "rid-1"
assert charge._tenant_id == "t1"
assert charge._feature_key == "trigger_event"
assert charge._amount == 1
mock_bs.quota_reserve.assert_called_once()
def test_reserve_no_reservation_id_raises(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {}
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_quota_exceeded_propagates(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_api_exception_returns_unlimited(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = RuntimeError("network")
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_consume_calls_reserve_and_commit(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
mock_bs.quota_commit.return_value = {}
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
assert charge.success is True
mock_bs.quota_commit.assert_called_once()
def test_check_billing_disabled(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = False
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_check_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
def test_check_sufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=100),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
def test_check_insufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=5),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
def test_check_unlimited_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=-1),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
def test_check_exception_returns_true(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_release_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = False
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_empty_reservation(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.return_value = {}
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_called_once_with(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
)
def test_release_exception_swallowed(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.side_effect = RuntimeError("fail")
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_get_remaining_normal(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
def test_get_remaining_unlimited(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_over_limit_returns_zero(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_exception_returns_neg1(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.side_effect = RuntimeError
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_empty_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_non_dict_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = "invalid"
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_feature_not_in_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
assert remaining == 0
def test_get_remaining_non_dict_feature_info(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
class TestQuotaCharge:
def test_commit_success(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
mock_bs.quota_commit.assert_called_once_with(
tenant_id="t1",
feature_key="trigger_event",
reservation_id="rid-1",
actual_amount=1,
)
assert charge._committed is True
def test_commit_with_actual_amount(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=10,
)
charge.commit(actual_amount=5)
call_kwargs = mock_bs.quota_commit.call_args[1]
assert call_kwargs["actual_amount"] == 5
def test_commit_idempotent(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
charge.commit()
assert mock_bs.quota_commit.call_count == 1
def test_commit_no_charge_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_no_tenant_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
_feature_key="trigger_event",
)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_exception_swallowed(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.side_effect = RuntimeError("fail")
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
def test_refund_success(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
)
charge.refund()
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_refund_no_charge_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.refund()
mock_rel.assert_not_called()
def test_refund_no_tenant_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
)
charge.refund()
mock_rel.assert_not_called()
class TestUnlimited:
def test_unlimited_returns_success_with_no_charge_id(self):
charge = unlimited()
assert charge.success is True
assert charge.charge_id is None
assert charge._quota_type == QuotaType.UNLIMITED

View File

@ -23,7 +23,6 @@ import pytest
import services.app_generate_service as ags_module
from core.app.entities.app_invoke_entities import InvokeFrom
from enums.quota_type import QuotaType
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
@ -448,8 +447,8 @@ class TestGenerateBilling:
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
reserve_mock = mocker.patch(
"services.app_generate_service.QuotaService.reserve",
consume_mock = mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
return_value=quota_charge,
)
mocker.patch(
@ -468,8 +467,7 @@ class TestGenerateBilling:
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
quota_charge.commit.assert_called_once()
consume_mock.assert_called_once_with("tenant-id")
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
from services.errors.app import QuotaExceededError
@ -477,7 +475,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
mocker.patch(
"services.app_generate_service.QuotaService.reserve",
"services.app_generate_service.QuotaType.WORKFLOW.consume",
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
)
@ -494,7 +492,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
mocker.patch(
"services.app_generate_service.QuotaService.reserve",
"services.app_generate_service.QuotaType.WORKFLOW.consume",
return_value=quota_charge,
)
mocker.patch(

View File

@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
- repo: SQLAlchemyWorkflowTriggerLogRepository
- dispatcher_manager_class: QueueDispatcherManager class
- dispatcher: dispatcher instance
- quota_service: QuotaService mock
- quota_workflow: QuotaType.WORKFLOW
- get_workflow: AsyncWorkflowService._get_workflow method
- professional_task: execute_workflow_professional
- team_task: execute_workflow_team
@ -72,7 +72,12 @@ class TestAsyncWorkflowService:
mock_repo.create.side_effect = _create_side_effect
mock_dispatcher = MagicMock()
mock_quota_service = MagicMock()
quota_workflow = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
mock_team_task = MagicMock()
mock_sandbox_task = MagicMock()
with (
patch.object(
@ -88,8 +93,8 @@ class TestAsyncWorkflowService:
) as mock_get_workflow,
patch.object(
async_workflow_service_module,
"QuotaService",
new=mock_quota_service,
"QuotaType",
new=SimpleNamespace(WORKFLOW=quota_workflow),
),
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
@ -102,7 +107,7 @@ class TestAsyncWorkflowService:
"repo": mock_repo,
"dispatcher_manager_class": mock_dispatcher_manager_class,
"dispatcher": mock_dispatcher,
"quota_service": mock_quota_service,
"quota_workflow": quota_workflow,
"get_workflow": mock_get_workflow,
"professional_task": mock_professional_task,
"team_task": mock_team_task,
@ -141,9 +146,6 @@ class TestAsyncWorkflowService:
mocks["team_task"].delay.return_value = task_result
mocks["sandbox_task"].delay.return_value = task_result
quota_charge_mock = MagicMock()
mocks["quota_service"].reserve.return_value = quota_charge_mock
class DummyAccount:
def __init__(self, user_id: str):
self.id = user_id
@ -161,9 +163,8 @@ class TestAsyncWorkflowService:
assert result.status == "queued"
assert result.queue == queue_name
mocks["quota_service"].reserve.assert_called_once()
quota_charge_mock.commit.assert_called_once()
assert session.commit.call_count == 3
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
assert session.commit.call_count == 2
created_log = mocks["repo"].create.call_args[0][0]
assert created_log.status == WorkflowTriggerStatus.QUEUED
@ -249,7 +250,7 @@ class TestAsyncWorkflowService:
mocks = async_workflow_trigger_mocks
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
mocks["get_workflow"].return_value = workflow
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
feature="workflow",
tenant_id="tenant-123",
required=1,
@ -266,7 +267,7 @@ class TestAsyncWorkflowService:
trigger_data=trigger_data,
)
assert session.commit.call_count == 3
assert session.commit.call_count == 2
updated_log = mocks["repo"].update.call_args[0][0]
assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED
assert "Quota limit reached" in updated_log.error
@ -462,7 +463,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
# Assert
assert result == workflow
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123", session=None)
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
workflow_service.get_published_workflow.assert_not_called()
def test_should_raise_when_specific_workflow_id_not_found(self):
@ -490,7 +491,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
# Assert
assert result == workflow
workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None)
workflow_service.get_published_workflow.assert_called_once_with(app_model)
workflow_service.get_published_workflow_by_id.assert_not_called()
def test_should_raise_when_default_published_workflow_not_found(self):

View File

@ -290,19 +290,9 @@ class TestBillingServiceSubscriptionInfo:
# Arrange
tenant_id = "tenant-123"
expected_response = {
"enabled": True,
"subscription": {"plan": "professional", "interval": "month", "education": False},
"members": {"size": 1, "limit": 50},
"apps": {"size": 1, "limit": 200},
"vector_space": {"size": 0.0, "limit": 20480},
"knowledge_rate_limit": {"limit": 1000},
"documents_upload_quota": {"size": 0, "limit": 1000},
"annotation_quota_limit": {"size": 0, "limit": 5000},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": 1775952000,
"subscription_plan": "professional",
"billing_cycle": "monthly",
"status": "active",
}
mock_send_request.return_value = expected_response
@ -425,7 +415,7 @@ class TestBillingServiceUsageCalculation:
yield mock
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
"""Test retrieval of tenant feature plan usage information."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
@ -438,20 +428,6 @@ class TestBillingServiceUsageCalculation:
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
def test_get_quota_info(self, mock_send_request):
"""Test retrieval of quota info from new endpoint."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
mock_send_request.return_value = expected_response
# Act
result = BillingService.get_quota_info(tenant_id)
# Assert
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
"""Test updating tenant feature usage with positive delta (adding credits)."""
# Arrange
@ -529,150 +505,6 @@ class TestBillingServiceUsageCalculation:
)
class TestBillingServiceQuotaOperations:
"""Unit tests for quota reserve/commit/release operations."""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
def test_quota_reserve_success(self, mock_send_request):
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
mock_send_request.return_value = expected
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/reserve",
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
)
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
assert result["available"] == 99
assert isinstance(result["available"], int)
assert result["reserved"] == 1
assert isinstance(result["reserved"], int)
def test_quota_reserve_with_meta(self, mock_send_request):
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
meta = {"source": "webhook"}
BillingService.quota_reserve(
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"source": "webhook"}
def test_quota_commit_success(self, mock_send_request):
expected = {"available": 98, "reserved": 0, "refunded": 0}
mock_send_request.return_value = expected
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/commit",
json={
"tenant_id": "t1",
"feature_key": "trigger_event",
"reservation_id": "rid-1",
"actual_amount": 1,
},
)
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
)
assert result["available"] == 97
assert isinstance(result["available"], int)
assert result["refunded"] == 1
assert isinstance(result["refunded"], int)
def test_quota_commit_with_meta(self, mock_send_request):
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
meta = {"reason": "partial"}
BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"reason": "partial"}
def test_quota_release_success(self, mock_send_request):
expected = {"available": 100, "reserved": 0, "released": 1}
mock_send_request.return_value = expected
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/release",
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
)
def test_quota_release_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
assert result["available"] == 100
assert isinstance(result["available"], int)
assert result["released"] == 1
assert isinstance(result["released"], int)
def test_get_quota_info_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int for get_quota_info."""
mock_send_request.return_value = {
"trigger_event": {"usage": "42", "limit": "3000", "reset_date": "1700000000"},
"api_rate_limit": {"usage": "10", "limit": "-1", "reset_date": "-1"},
}
result = BillingService.get_quota_info("t1")
assert result["trigger_event"]["usage"] == 42
assert isinstance(result["trigger_event"]["usage"], int)
assert result["trigger_event"]["limit"] == 3000
assert isinstance(result["trigger_event"]["limit"], int)
assert result["trigger_event"]["reset_date"] == 1700000000
assert isinstance(result["trigger_event"]["reset_date"], int)
assert result["api_rate_limit"]["limit"] == -1
assert isinstance(result["api_rate_limit"]["limit"], int)
def test_get_quota_info_accepts_int_values(self, mock_send_request):
"""Test that get_quota_info works with native int values."""
expected = {
"trigger_event": {"usage": 42, "limit": 3000, "reset_date": 1700000000},
"api_rate_limit": {"usage": 0, "limit": -1},
}
mock_send_request.return_value = expected
result = BillingService.get_quota_info("t1")
assert result["trigger_event"]["usage"] == 42
assert result["trigger_event"]["limit"] == 3000
assert result["api_rate_limit"]["limit"] == -1
class TestBillingServiceRateLimitEnforcement:
"""Unit tests for rate limit enforcement mechanisms.
@ -1177,14 +1009,17 @@ class TestBillingServiceEdgeCases:
yield mock
def test_get_info_empty_response(self, mock_send_request):
"""Empty response from billing API should raise ValidationError due to missing required fields."""
from pydantic import ValidationError
"""Test handling of empty billing info response."""
# Arrange
tenant_id = "tenant-empty"
mock_send_request.return_value = {}
with pytest.raises(ValidationError):
BillingService.get_info(tenant_id)
# Act
result = BillingService.get_info(tenant_id)
# Assert
assert result == {}
mock_send_request.assert_called_once()
def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request):
"""Test updating tenant feature usage with zero delta (no change)."""
@ -1599,21 +1434,12 @@ class TestBillingServiceIntegrationScenarios:
# Step 1: Get current billing info
mock_send_request.return_value = {
"enabled": True,
"subscription": {"plan": "sandbox", "interval": "", "education": False},
"members": {"size": 0, "limit": 1},
"apps": {"size": 0, "limit": 5},
"vector_space": {"size": 0.0, "limit": 50},
"knowledge_rate_limit": {"limit": 10},
"documents_upload_quota": {"size": 0, "limit": 50},
"annotation_quota_limit": {"size": 0, "limit": 10},
"docs_processing": "standard",
"can_replace_logo": False,
"model_load_balancing_enabled": False,
"knowledge_pipeline_publish_enabled": False,
"subscription_plan": "sandbox",
"billing_cycle": "monthly",
"status": "active",
}
current_info = BillingService.get_info(tenant_id)
assert current_info["subscription"]["plan"] == "sandbox"
assert current_info["subscription_plan"] == "sandbox"
# Step 2: Get payment link for upgrade
mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"}
@ -1727,140 +1553,3 @@ class TestBillingServiceIntegrationScenarios:
mock_send_request.return_value = {"result": "success", "activated": True}
activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student")
assert activate_result["activated"] is True
class TestBillingServiceSubscriptionInfoDataType:
"""Unit tests for data type coercion in BillingService.get_info
1. Verifies the get_info returns correct Python types for numeric fields
2. Ensure the compatibility regardless of what results the upstream billing API returns
"""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
@pytest.fixture
def normal_billing_response(self) -> dict:
return {
"enabled": True,
"subscription": {
"plan": "team",
"interval": "year",
"education": False,
},
"members": {"size": 10, "limit": 50},
"apps": {"size": 80, "limit": 200},
"vector_space": {"size": 5120.75, "limit": 20480},
"knowledge_rate_limit": {"limit": 1000},
"documents_upload_quota": {"size": 450, "limit": 1000},
"annotation_quota_limit": {"size": 1200, "limit": 5000},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": 1745971200,
}
@pytest.fixture
def string_billing_response(self) -> dict:
return {
"enabled": True,
"subscription": {
"plan": "team",
"interval": "year",
"education": False,
},
"members": {"size": "10", "limit": "50"},
"apps": {"size": "80", "limit": "200"},
"vector_space": {"size": 5120.75, "limit": "20480"},
"knowledge_rate_limit": {"limit": "1000"},
"documents_upload_quota": {"size": "450", "limit": "1000"},
"annotation_quota_limit": {"size": "1200", "limit": "5000"},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": "1745971200",
}
@staticmethod
def _assert_billing_info_types(result: dict):
assert isinstance(result["enabled"], bool)
assert isinstance(result["subscription"]["plan"], str)
assert isinstance(result["subscription"]["interval"], str)
assert isinstance(result["subscription"]["education"], bool)
assert isinstance(result["members"]["size"], int)
assert isinstance(result["members"]["limit"], int)
assert isinstance(result["apps"]["size"], int)
assert isinstance(result["apps"]["limit"], int)
assert isinstance(result["vector_space"]["size"], float)
assert isinstance(result["vector_space"]["limit"], int)
assert isinstance(result["knowledge_rate_limit"]["limit"], int)
assert isinstance(result["documents_upload_quota"]["size"], int)
assert isinstance(result["documents_upload_quota"]["limit"], int)
assert isinstance(result["annotation_quota_limit"]["size"], int)
assert isinstance(result["annotation_quota_limit"]["limit"], int)
assert isinstance(result["docs_processing"], str)
assert isinstance(result["can_replace_logo"], bool)
assert isinstance(result["model_load_balancing_enabled"], bool)
assert isinstance(result["knowledge_pipeline_publish_enabled"], bool)
if "next_credit_reset_date" in result:
assert isinstance(result["next_credit_reset_date"], int)
def test_get_info_with_normal_types(self, mock_send_request, normal_billing_response):
"""When the billing API returns native numeric types, get_info should preserve them."""
mock_send_request.return_value = normal_billing_response
result = BillingService.get_info("tenant-type-test")
self._assert_billing_info_types(result)
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
def test_get_info_with_string_types(self, mock_send_request, string_billing_response):
"""When the billing API returns numeric values as strings, get_info should coerce them."""
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
self._assert_billing_info_types(result)
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response):
"""NotRequired fields can be absent without raising."""
del string_billing_response["next_credit_reset_date"]
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
assert "next_credit_reset_date" not in result
self._assert_billing_info_types(result)
def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response):
"""Undefined fields are silently stripped by validate_python."""
string_billing_response["new_feature"] = "something"
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
# extra fields are dropped by TypeAdapter on TypedDict
assert "new_feature" not in result
self._assert_billing_info_types(result)
def test_get_info_missing_required_field_raises(self, mock_send_request, string_billing_response):
"""Missing a required field should raise ValidationError."""
from pydantic import ValidationError
del string_billing_response["members"]
mock_send_request.return_value = string_billing_response
with pytest.raises(ValidationError):
BillingService.get_info("tenant-type-test")

View File

@ -337,7 +337,10 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock()
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock()
mock_db_session.session.scalar.return_value = mock_workflow
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
result = workflow_service.get_draft_workflow(app)
@ -347,7 +350,10 @@ class TestWorkflowService:
"""Test get_draft_workflow returns None when no draft exists."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
mock_db_session.session.scalar.return_value = None
# Mock database query to return None
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
result = workflow_service.get_draft_workflow(app)
@ -359,7 +365,10 @@ class TestWorkflowService:
workflow_id = "workflow-123"
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(version="v1")
mock_db_session.session.scalar.return_value = mock_workflow
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
result = workflow_service.get_draft_workflow(app, workflow_id=workflow_id)
@ -374,7 +383,10 @@ class TestWorkflowService:
workflow_id = "workflow-123"
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
mock_db_session.session.scalar.return_value = mock_workflow
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
@ -393,7 +405,10 @@ class TestWorkflowService:
workflow_id=workflow_id, version=Workflow.VERSION_DRAFT
)
mock_db_session.session.scalar.return_value = mock_workflow
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
with pytest.raises(IsDraftWorkflowError):
workflow_service.get_published_workflow_by_id(app, workflow_id)
@ -403,7 +418,10 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock()
workflow_id = "nonexistent-workflow"
mock_db_session.session.scalar.return_value = None
# Mock database query to return None
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
result = workflow_service.get_published_workflow_by_id(app, workflow_id)
@ -415,7 +433,10 @@ class TestWorkflowService:
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=workflow_id)
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
mock_db_session.session.scalar.return_value = mock_workflow
# Mock database query
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
result = workflow_service.get_published_workflow(app)
@ -444,7 +465,11 @@ class TestWorkflowService:
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
features = {"file_upload": {"enabled": False}}
mock_db_session.session.scalar.return_value = None
# Mock get_draft_workflow to return None (no existing draft)
# This simulates the first time a workflow is created for an app
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = None
with (
patch.object(workflow_service, "validate_features_structure"),
@ -481,7 +506,9 @@ class TestWorkflowService:
# Mock existing draft workflow
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash=unique_hash)
mock_db_session.session.scalar.return_value = mock_workflow
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
with (
patch.object(workflow_service, "validate_features_structure"),
@ -520,7 +547,9 @@ class TestWorkflowService:
# Mock existing draft workflow with different hash
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(unique_hash="old-hash")
mock_db_session.session.scalar.return_value = mock_workflow
mock_query = MagicMock()
mock_db_session.session.query.return_value = mock_query
mock_query.where.return_value.first.return_value = mock_workflow
with pytest.raises(WorkflowHashNotEqualError):
workflow_service.sync_draft_workflow(

View File

@ -1,204 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
import tasks.trigger_processing_tasks as trigger_processing_tasks_module
from services.errors.app import QuotaExceededError
from tasks.trigger_processing_tasks import dispatch_triggered_workflow
class TestDispatchTriggeredWorkflow:
"""Unit tests covering branch behaviours of ``dispatch_triggered_workflow``.
The covered branches are:
- workflow missing for ``plugin_trigger.app_id`` → log + ``continue``
- ``QuotaService.reserve`` raising ``QuotaExceededError`` →
``mark_tenant_triggers_rate_limited`` + early ``return``
- ``trigger_workflow_async`` succeeds →
``quota_charge.commit()`` + ``dispatched_count`` increments
"""
@pytest.fixture
def subscription(self):
sub = MagicMock()
sub.id = "subscription-123"
sub.tenant_id = "tenant-123"
sub.provider_id = "langgenius/test_plugin/test_plugin"
sub.endpoint_id = "endpoint-123"
sub.credentials = {}
sub.credential_type = "api_key"
return sub
@pytest.fixture
def plugin_trigger(self):
trigger = MagicMock()
trigger.id = "plugin-trigger-123"
trigger.app_id = "app-123"
trigger.node_id = "node-123"
return trigger
@pytest.fixture
def provider_controller(self):
controller = MagicMock()
controller.plugin_unique_identifier = "langgenius/test_plugin:0.0.1"
controller.entity.identity.name = "Test Plugin"
controller.entity.identity.icon = "icon.svg"
controller.entity.identity.icon_dark = "icon_dark.svg"
return controller
@pytest.fixture
def dispatch_mocks(self, subscription, plugin_trigger, provider_controller):
"""Patch all external dependencies reached by ``dispatch_triggered_workflow``.
Defaults are configured so the code flow can reach the final async
trigger block (line ~385); each test overrides specific handles
(``get_workflows``, ``reserve``, ``create_end_user_batch``, ...) to
drive the path it targets.
"""
session_cm = MagicMock()
session_cm.__enter__.return_value = MagicMock()
session_cm.__exit__.return_value = False
invoke_response = MagicMock()
invoke_response.cancelled = False
invoke_response.variables = {}
quota_charge = MagicMock()
with (
patch.object(
trigger_processing_tasks_module.TriggerHttpRequestCachingService,
"get_request",
return_value=MagicMock(),
),
patch.object(
trigger_processing_tasks_module.TriggerHttpRequestCachingService,
"get_payload",
return_value=MagicMock(),
),
patch.object(
trigger_processing_tasks_module.TriggerSubscriptionOperatorService,
"get_subscriber_triggers",
return_value=[plugin_trigger],
),
patch.object(
trigger_processing_tasks_module.TriggerManager,
"get_trigger_provider",
return_value=provider_controller,
),
patch.object(
trigger_processing_tasks_module.TriggerManager,
"invoke_trigger_event",
return_value=invoke_response,
) as invoke_trigger_event,
patch.object(
trigger_processing_tasks_module.TriggerEventNodeData,
"model_validate",
return_value=MagicMock(),
),
patch.object(
trigger_processing_tasks_module,
"_get_latest_workflows_by_app_ids",
) as get_workflows,
patch.object(
trigger_processing_tasks_module.EndUserService,
"create_end_user_batch",
return_value={},
) as create_end_user_batch,
patch.object(
trigger_processing_tasks_module.session_factory,
"create_session",
return_value=session_cm,
),
patch.object(
trigger_processing_tasks_module.QuotaService,
"reserve",
return_value=quota_charge,
) as reserve,
patch.object(
trigger_processing_tasks_module.AppTriggerService,
"mark_tenant_triggers_rate_limited",
) as mark_rate_limited,
patch.object(
trigger_processing_tasks_module.AsyncWorkflowService,
"trigger_workflow_async",
) as trigger_workflow_async,
):
yield {
"get_workflows": get_workflows,
"reserve": reserve,
"quota_charge": quota_charge,
"mark_rate_limited": mark_rate_limited,
"invoke_trigger_event": invoke_trigger_event,
"invoke_response": invoke_response,
"create_end_user_batch": create_end_user_batch,
"trigger_workflow_async": trigger_workflow_async,
}
def test_dispatch_skips_when_workflow_missing(self, subscription, dispatch_mocks):
"""Covers missing workflow → log + ``continue``."""
dispatch_mocks["get_workflows"].return_value = {}
dispatched = dispatch_triggered_workflow(
user_id="user-123",
subscription=subscription,
event_name="test_event",
request_id="request-123",
)
assert dispatched == 0
dispatch_mocks["reserve"].assert_not_called()
dispatch_mocks["invoke_trigger_event"].assert_not_called()
dispatch_mocks["mark_rate_limited"].assert_not_called()
def test_dispatch_marks_rate_limited_when_quota_exceeded(self, subscription, plugin_trigger, dispatch_mocks):
"""Covers QuotaExceededError → mark rate-limited + early return."""
workflow_mock = MagicMock()
workflow_mock.walk_nodes.return_value = iter(
[(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})]
)
dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock}
dispatch_mocks["reserve"].side_effect = QuotaExceededError(
feature="trigger", tenant_id=subscription.tenant_id, required=1
)
dispatched = dispatch_triggered_workflow(
user_id="user-123",
subscription=subscription,
event_name="test_event",
request_id="request-123",
)
assert dispatched == 0
dispatch_mocks["reserve"].assert_called_once()
dispatch_mocks["mark_rate_limited"].assert_called_once_with(subscription.tenant_id)
dispatch_mocks["invoke_trigger_event"].assert_not_called()
def test_dispatch_commits_quota_and_counts_when_workflow_triggered(
self, subscription, plugin_trigger, dispatch_mocks
):
"""Happy path: end user exists and async trigger succeeds."""
workflow_mock = MagicMock()
workflow_mock.id = "workflow-123"
workflow_mock.walk_nodes.return_value = iter(
[(plugin_trigger.node_id, {"type": trigger_processing_tasks_module.TRIGGER_PLUGIN_NODE_TYPE})]
)
dispatch_mocks["get_workflows"].return_value = {plugin_trigger.app_id: workflow_mock}
end_user_mock = MagicMock()
dispatch_mocks["create_end_user_batch"].return_value = {plugin_trigger.app_id: end_user_mock}
dispatched = dispatch_triggered_workflow(
user_id="user-123",
subscription=subscription,
event_name="test_event",
request_id="request-123",
)
assert dispatched == 1
dispatch_mocks["trigger_workflow_async"].assert_called_once()
_, kwargs = dispatch_mocks["trigger_workflow_async"].call_args
assert kwargs["user"] is end_user_mock
dispatch_mocks["quota_charge"].commit.assert_called_once()
dispatch_mocks["quota_charge"].refund.assert_not_called()
dispatch_mocks["mark_rate_limited"].assert_not_called()

2
api/uv.lock generated
View File

@ -1457,7 +1457,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.13.3"
version = "1.13.2"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.13.3
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.13.3
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.13.3
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.13.3
image: langgenius/dify-web:1.13.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -245,7 +245,7 @@ services:
# The DifySandbox
sandbox:
image: langgenius/dify-sandbox:0.2.14
image: langgenius/dify-sandbox:0.2.12
restart: always
environment:
# The DifySandbox configurations
@ -269,7 +269,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.5.3-local
image: langgenius/dify-plugin-daemon:0.5.4-local
restart: always
environment:
# Use the shared environment variables.

View File

@ -97,7 +97,7 @@ services:
# The DifySandbox
sandbox:
image: langgenius/dify-sandbox:0.2.14
image: langgenius/dify-sandbox:0.2.12
restart: always
env_file:
- ./middleware.env
@ -123,7 +123,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.5.3-local
image: langgenius/dify-plugin-daemon:0.5.4-local
restart: always
env_file:
- ./middleware.env

View File

@ -731,7 +731,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.13.3
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -773,7 +773,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.13.3
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -812,7 +812,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.13.3
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -842,7 +842,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.13.3
image: langgenius/dify-web:1.13.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -955,7 +955,7 @@ services:
# The DifySandbox
sandbox:
image: langgenius/dify-sandbox:0.2.14
image: langgenius/dify-sandbox:0.2.12
restart: always
environment:
# The DifySandbox configurations
@ -979,7 +979,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.5.3-local
image: langgenius/dify-plugin-daemon:0.5.4-local
restart: always
environment:
# Use the shared environment variables.

View File

@ -5,8 +5,7 @@ app:
max_workers: 4
max_requests: 50
worker_timeout: 5
python_path: /opt/python/bin/python3
nodejs_path: /usr/local/bin/node
python_path: /usr/local/bin/python3
enable_network: True # please make sure there is no network risk in your environment
allowed_syscalls: # please leave it empty if you have no idea how seccomp works
proxy:

View File

@ -5,7 +5,7 @@ app:
max_workers: 4
max_requests: 50
worker_timeout: 5
python_path: /opt/python/bin/python3
python_path: /usr/local/bin/python3
python_lib_path:
- /usr/local/lib/python3.10
- /usr/lib/python3.10

View File

@ -501,16 +501,6 @@ describe('Question component', () => {
expect(onRegenerate).toHaveBeenCalled()
})
it('should render default question avatar icon when questionIcon is not provided', () => {
const { container } = renderWithProvider(
makeItem(),
vi.fn() as unknown as OnRegenerate,
)
const defaultIcon = container.querySelector('.question-default-user-icon')
expect(defaultIcon).toBeInTheDocument()
})
it('should render custom questionIcon when provided', () => {
const { container } = renderWithProvider(
makeItem(),
@ -519,7 +509,7 @@ describe('Question component', () => {
)
expect(screen.getByTestId('custom-question-icon')).toBeInTheDocument()
const defaultIcon = container.querySelector('.question-default-user-icon')
const defaultIcon = container.querySelector('.i-custom-public-avatar-user')
expect(defaultIcon).not.toBeInTheDocument()
})

View File

@ -15,7 +15,6 @@ import {
import { useTranslation } from 'react-i18next'
import Textarea from 'react-textarea-autosize'
import { FileList } from '@/app/components/base/file-uploader'
import { User } from '@/app/components/base/icons/src/public/avatar'
import { Markdown } from '@/app/components/base/markdown'
import { cn } from '@/utils/classnames'
import ActionButton from '../../action-button'
@ -244,7 +243,7 @@ const Question: FC<QuestionProps> = ({
{
questionIcon || (
<div className="h-full w-full rounded-full border-[0.5px] border-black/5">
<User className="question-default-user-icon h-full w-full" />
<div className="i-custom-public-avatar-user h-full w-full" />
</div>
)
}

View File

@ -142,7 +142,7 @@ const ApiKeyModal = ({
onExtraButtonClick={onRemove}
disabled={disabled || isLoading || doingAction}
clickOutsideNotClose={true}
wrapperClassName="!z-[1002]"
wrapperClassName="!z-[101]"
>
{pluginPayload.detail && (
<ReadmeEntrance pluginDetail={pluginPayload.detail} showType={ReadmeShowType.modal} />

View File

@ -157,7 +157,7 @@ const OAuthClientSettings = ({
)
}
containerClassName="pt-0"
wrapperClassName="!z-[1002]"
wrapperClassName="!z-[101]"
clickOutsideNotClose={true}
>
{pluginPayload.detail && (

View File

@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.13.3",
"version": "1.13.2",
"private": true,
"packageManager": "pnpm@10.32.1",
"imports": {
@ -125,15 +125,15 @@
"mime": "4.1.0",
"mitt": "3.0.1",
"negotiator": "1.0.0",
"next": "16.2.3",
"next": "16.2.1",
"next-themes": "0.4.6",
"nuqs": "2.8.9",
"pinyin-pro": "3.28.0",
"qrcode.react": "4.2.0",
"qs": "6.15.0",
"react": "19.2.5",
"react": "19.2.4",
"react-18-input-autosize": "3.0.0",
"react-dom": "19.2.5",
"react-dom": "19.2.4",
"react-easy-crop": "5.5.6",
"react-hotkeys-hook": "5.2.4",
"react-i18next": "16.6.1",
@ -173,8 +173,8 @@
"@mdx-js/loader": "3.1.1",
"@mdx-js/react": "3.1.1",
"@mdx-js/rollup": "3.1.1",
"@next/eslint-plugin-next": "16.2.3",
"@next/mdx": "16.2.3",
"@next/eslint-plugin-next": "16.2.1",
"@next/mdx": "16.2.1",
"@rgrove/parse-xml": "4.2.0",
"@storybook/addon-docs": "10.3.1",
"@storybook/addon-links": "10.3.1",
@ -231,7 +231,7 @@
"nock": "14.0.11",
"postcss": "8.5.8",
"postcss-js": "5.1.0",
"react-server-dom-webpack": "19.2.5",
"react-server-dom-webpack": "19.2.4",
"sass": "1.98.0",
"storybook": "10.3.1",
"tailwindcss": "3.4.19",

1066
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff