Merge branch 'feat/tidb-endpoint' into deploy/dev

This commit is contained in:
Yansong Zhang
2026-04-15 14:32:35 +08:00
4 changed files with 220 additions and 373 deletions

View File

@ -1,324 +0,0 @@
import logging
import time
import uuid
from collections.abc import Sequence
import httpx
from httpx import DigestAuth
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
logger = logging.getLogger(__name__)
# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn
_tidb_http_client: httpx.Client = get_pooled_http_client(
"tidb:cloud",
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
)
class TidbService:
@staticmethod
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
"""Fetch the qdrant endpoint for a cluster by calling the Get Cluster API.
The v1beta1 serverless Get Cluster response contains
``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``).
We prepend ``qdrant-`` and wrap it as an ``https://`` URL.
"""
try:
logger.info("Fetching qdrant endpoint for cluster %s", cluster_id)
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if not cluster_response:
logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id)
return None
# v1beta1 serverless: endpoints.public.host
endpoints = cluster_response.get("endpoints") or {}
public = endpoints.get("public") or {}
host = public.get("host")
if host:
qdrant_url = f"https://qdrant-{host}"
logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url)
return qdrant_url
logger.warning(
"No endpoints.public.host found for cluster %s, response keys: %s",
cluster_id,
list(cluster_response.keys()),
)
except Exception:
logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id)
return None
@staticmethod
def create_tidb_serverless_cluster(
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
):
"""
Creates a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")[:16]
cluster_data = {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region)
response = _tidb_http_client.post(
f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_id = response_data["clusterId"]
logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id)
retry_count = 0
max_retries = 30
while retry_count < max_retries:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, cluster_id)
logger.info(
"Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s",
cluster_id,
user_prefix,
qdrant_endpoint,
)
return {
"cluster_id": cluster_id,
"cluster_name": display_name,
"account": f"{user_prefix}.root",
"password": password,
"qdrant_endpoint": qdrant_endpoint,
}
logger.info(
"Cluster %s state=%s, retry %d/%d",
cluster_id,
cluster_response["state"],
retry_count + 1,
max_retries,
)
time.sleep(30)
retry_count += 1
logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries)
else:
logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text)
response.raise_for_status()
@staticmethod
def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = _tidb_http_client.delete(
f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = _tidb_http_client.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def change_tidb_serverless_root_password(
api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster for which the password is to be changed (required).+
:param account: The account for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = _tidb_http_client.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=DigestAuth(public_key, private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: Sequence[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,
public_key: str,
private_key: str,
):
"""
Update the status of a new TiDB Serverless cluster.
:param tidb_serverless_list: The TiDB serverless list (required).
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:return: The response from the API.
"""
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = _tidb_http_client.get(
f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()
else:
response.raise_for_status()
@staticmethod
def batch_create_tidb_serverless_cluster(
batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
) -> list[dict]:
"""
Creates a new TiDB Serverless cluster.
:param batch_size: The batch size (required).
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
clusters = []
for _ in range(batch_size):
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")
cluster_data = {
"cluster": {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
}
cache_key = f"tidb_serverless_cluster_password:{display_name}"
redis_client.setex(cache_key, 3600, password)
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = _tidb_http_client.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
logger.info("Batch created %d clusters", len(response_data.get("clusters", [])))
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
cached_password = redis_client.get(cache_key)
if not cached_password:
logger.warning("No cached password for cluster %s, skipping", item["displayName"])
continue
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
logger.info(
"Batch cluster %s: qdrant_endpoint=%s",
item["clusterId"],
qdrant_endpoint,
)
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": cached_password.decode("utf-8"),
"qdrant_endpoint": qdrant_endpoint,
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text)
response.raise_for_status()
return []

View File

@ -24,12 +24,25 @@ _tidb_http_client: httpx.Client = get_pooled_http_client(
class TidbService:
@staticmethod
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
"""Fetch the qdrant endpoint for a cluster by calling the Get Cluster API.
def extract_qdrant_endpoint(cluster_response: dict) -> str | None:
"""Extract the qdrant endpoint URL from a Get Cluster API response.
The v1beta1 serverless Get Cluster response contains
``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``).
We prepend ``qdrant-`` and wrap it as an ``https://`` URL.
Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``),
prepends ``qdrant-`` and wraps it as an ``https://`` URL.
"""
endpoints = cluster_response.get("endpoints") or {}
public = endpoints.get("public") or {}
host = public.get("host")
if host:
return f"https://qdrant-{host}"
return None
@staticmethod
def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None:
"""Call Get Cluster API and extract the qdrant endpoint.
Use ``extract_qdrant_endpoint`` instead when you already have
the cluster response to avoid a redundant API call.
"""
try:
logger.info("Fetching qdrant endpoint for cluster %s", cluster_id)
@ -37,12 +50,8 @@ class TidbService:
if not cluster_response:
logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id)
return None
# v1beta1 serverless: endpoints.public.host
endpoints = cluster_response.get("endpoints") or {}
public = endpoints.get("public") or {}
host = public.get("host")
if host:
qdrant_url = f"https://qdrant-{host}"
qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response)
if qdrant_url:
logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url)
return qdrant_url
logger.warning(
@ -106,10 +115,12 @@ class TidbService:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, cluster_id)
qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response)
logger.info(
"Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s",
cluster_id, user_prefix, qdrant_endpoint,
cluster_id,
user_prefix,
qdrant_endpoint,
)
return {
"cluster_id": cluster_id,
@ -118,7 +129,13 @@ class TidbService:
"password": password,
"qdrant_endpoint": qdrant_endpoint,
}
logger.info("Cluster %s state=%s, retry %d/%d", cluster_id, cluster_response["state"], retry_count + 1, max_retries)
logger.info(
"Cluster %s state=%s, retry %d/%d",
cluster_id,
cluster_response["state"],
retry_count + 1,
max_retries,
)
time.sleep(30)
retry_count += 1
logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries)
@ -295,12 +312,11 @@ class TidbService:
if not cached_password:
logger.warning("No cached password for cluster %s, skipping", item["displayName"])
continue
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(
api_url, public_key, private_key, item["clusterId"]
)
qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
logger.info(
"Batch cluster %s: qdrant_endpoint=%s",
item["clusterId"], qdrant_endpoint,
item["clusterId"],
qdrant_endpoint,
)
cluster_info = {
"cluster_id": item["clusterId"],

View File

@ -172,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds:
# Verify MatchAny structure
assert isinstance(field_condition.match, rest.MatchAny)
assert field_condition.match.any == ids
class TestInitVectorEndpointSelection:
"""Test that init_vector selects the correct qdrant endpoint.
We avoid importing the full module (which triggers Flask app context)
by testing the endpoint selection logic directly on TidbOnQdrantConfig.
"""
def test_uses_binding_endpoint_when_present(self):
binding_endpoint = "https://qdrant-custom.tidb.com"
global_url = "https://qdrant-global.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-custom.tidb.com"
config = TidbOnQdrantConfig(endpoint=qdrant_url)
assert config.endpoint == "https://qdrant-custom.tidb.com"
def test_falls_back_to_global_when_binding_endpoint_is_none(self):
binding_endpoint = None
global_url = "https://qdrant-global.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-global.tidb.com"
config = TidbOnQdrantConfig(endpoint=qdrant_url)
assert config.endpoint == "https://qdrant-global.tidb.com"
def test_falls_back_to_empty_when_both_none(self):
binding_endpoint = None
global_url = None
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == ""
config = TidbOnQdrantConfig(endpoint=qdrant_url)
assert config.endpoint == ""
def test_binding_endpoint_takes_precedence_over_global(self):
binding_endpoint = "https://qdrant-ap-southeast.tidb.com"
global_url = "https://qdrant-us-east.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-ap-southeast.tidb.com"
def test_empty_string_binding_endpoint_falls_back_to_global(self):
binding_endpoint = ""
global_url = "https://qdrant-global.tidb.com"
qdrant_url = binding_endpoint or global_url or ""
assert qdrant_url == "https://qdrant-global.tidb.com"

View File

@ -1,15 +1,36 @@
from unittest.mock import MagicMock, patch
import pytest
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
class TestExtractQdrantEndpoint:
"""Unit tests for TidbService.extract_qdrant_endpoint."""
def test_returns_endpoint_when_host_present(self):
response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}}
result = TidbService.extract_qdrant_endpoint(response)
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
def test_returns_none_when_host_missing(self):
response = {"endpoints": {"public": {}}}
assert TidbService.extract_qdrant_endpoint(response) is None
def test_returns_none_when_public_missing(self):
response = {"endpoints": {}}
assert TidbService.extract_qdrant_endpoint(response) is None
def test_returns_none_when_endpoints_missing(self):
assert TidbService.extract_qdrant_endpoint({}) is None
class TestFetchQdrantEndpoint:
"""Unit tests for TidbService.fetch_qdrant_endpoint."""
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_endpoint_when_host_present(self, mock_get_cluster):
mock_get_cluster.return_value = {
"status": {"connection_strings": {"standard": {"host": "gateway01.us-east-1.tidbcloud.com"}}}
"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}
}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
@ -17,65 +38,48 @@ class TestFetchQdrantEndpoint:
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster):
mock_get_cluster.return_value = None
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_host_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {"status": {"connection_strings": {"standard": {}}}}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
mock_get_cluster.return_value = {"endpoints": {"public": {}}}
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_status_missing(self, mock_get_cluster):
def test_returns_none_when_endpoints_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_connection_strings_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {"status": {}}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_on_exception(self, mock_get_cluster):
mock_get_cluster.side_effect = RuntimeError("network error")
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
@patch.object(TidbService, "get_tidb_serverless_cluster")
def test_returns_none_when_standard_key_missing(self, mock_get_cluster):
mock_get_cluster.return_value = {"status": {"connection_strings": {}}}
result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123")
assert result is None
assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None
class TestCreateTidbServerlessClusterQdrantEndpoint:
"""Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result."""
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster, mock_fetch_ep):
def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
mock_get_cluster.return_value = {
"state": "ACTIVE",
"userPrefix": "pfx",
"endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}},
}
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is not None
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
mock_fetch_ep.assert_called_once_with("url", "pub", "priv", "c-1")
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_result_qdrant_endpoint_none_when_fetch_fails(
self, mock_config, mock_http, mock_get_cluster, mock_fetch_ep
):
def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"}
@ -115,3 +119,100 @@ class TestBatchCreateTidbServerlessClusterQdrantEndpoint:
assert len(result) == 1
assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com"
class TestCreateTidbServerlessClusterRetry:
"""Cover retry/logging paths in create_tidb_serverless_cluster."""
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.side_effect = [
{"state": "CREATING", "userPrefix": ""},
{"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}},
]
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is not None
assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com"
assert mock_get_cluster.call_count == 2
@patch.object(TidbService, "get_tidb_serverless_cluster")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"})
mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""}
with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"):
result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
assert result is None
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_raises_on_post_failure(self, mock_config, mock_http):
mock_config.TIDB_SPEND_LIMIT = 10
mock_response = MagicMock(status_code=400, text="Bad Request")
mock_response.raise_for_status.side_effect = Exception("HTTP 400")
mock_http.post.return_value = mock_response
with pytest.raises(Exception, match="HTTP 400"):
TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1")
class TestBatchCreateEdgeCases:
"""Cover logging/edge-case branches in batch_create."""
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep):
mock_config.TIDB_SPEND_LIMIT = 10
mock_http.post.return_value = MagicMock(
status_code=200,
json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]},
)
mock_redis.setex = MagicMock()
mock_redis.get.return_value = None
result = TidbService.batch_create_tidb_serverless_cluster(
batch_size=1,
project_id="proj",
api_url="url",
iam_url="iam",
public_key="pub",
private_key="priv",
region="us-east-1",
)
assert len(result) == 0
mock_fetch_ep.assert_not_called()
@patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config")
def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis):
mock_config.TIDB_SPEND_LIMIT = 10
mock_response = MagicMock(status_code=500, text="Server Error")
mock_response.raise_for_status.side_effect = Exception("HTTP 500")
mock_http.post.return_value = mock_response
mock_redis.setex = MagicMock()
with pytest.raises(Exception, match="HTTP 500"):
TidbService.batch_create_tidb_serverless_cluster(
batch_size=1,
project_id="proj",
api_url="url",
iam_url="iam",
public_key="pub",
private_key="priv",
region="us-east-1",
)