From 95d1913f2c2fefffe9def3bbccdec31dad792072 Mon Sep 17 00:00:00 2001 From: L1nSn0w Date: Sat, 14 Feb 2026 13:45:01 +0800 Subject: [PATCH] feat(api): add timeout and error handling options to enterprise request Enhanced the BaseRequest class to include optional timeout and raise_for_status parameters for improved request handling. Updated the EnterpriseService to utilize these new options during account addition to the default workspace, ensuring better control over request behavior. Additionally, modified unit tests to reflect these changes. --- api/services/enterprise/base.py | 7 ++++++- api/services/enterprise/enterprise_service.py | 10 +++++++++- .../services/enterprise/test_enterprise_service.py | 2 ++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index e3832475aa..8f7e6ed525 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -39,6 +39,9 @@ class BaseRequest: endpoint: str, json: Any | None = None, params: Mapping[str, Any] | None = None, + *, + timeout: float | httpx.Timeout | None = None, + raise_for_status: bool = False, ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" @@ -53,7 +56,9 @@ class BaseRequest: logger.debug("Failed to generate traceparent header", exc_info=True) with httpx.Client(mounts=mounts) as client: - response = client.request(method, url, json=json, params=params, headers=headers) + response = client.request(method, url, json=json, params=params, headers=headers, timeout=timeout) + if raise_for_status: + response.raise_for_status() return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 632784ad20..dae1eeb1d6 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -9,6 +9,8 @@ from services.enterprise.base import EnterpriseRequest logger = logging.getLogger(__name__) +DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 + class WebAppSettings(BaseModel): access_mode: str = Field( @@ -97,7 +99,13 @@ class EnterpriseService: # Ensure we are sending a UUID-shaped string (enterprise side validates too). uuid.UUID(account_id) - data = EnterpriseRequest.send_request("POST", "/default-workspace/members", json={"account_id": account_id}) + data = EnterpriseRequest.send_request( + "POST", + "/default-workspace/members", + json={"account_id": account_id}, + timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS, + raise_for_status=True, + ) if not isinstance(data, dict): raise ValueError("Invalid response format from enterprise default workspace API") return DefaultWorkspaceJoinResult.model_validate(data) diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py index b4201aa061..2f7905fb06 100644 --- a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -36,6 +36,8 @@ class TestJoinDefaultWorkspace: "POST", "/default-workspace/members", json={"account_id": account_id}, + timeout=1.0, + raise_for_status=True, ) def test_join_default_workspace_invalid_response_format_raises(self):