diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index e3832475aa..8f7e6ed525 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -39,6 +39,9 @@ class BaseRequest: endpoint: str, json: Any | None = None, params: Mapping[str, Any] | None = None, + *, + timeout: float | httpx.Timeout | None = None, + raise_for_status: bool = False, ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" @@ -53,7 +56,9 @@ class BaseRequest: logger.debug("Failed to generate traceparent header", exc_info=True) with httpx.Client(mounts=mounts) as client: - response = client.request(method, url, json=json, params=params, headers=headers) + response = client.request(method, url, json=json, params=params, headers=headers, timeout=timeout) + if raise_for_status: + response.raise_for_status() return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 632784ad20..dae1eeb1d6 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -9,6 +9,8 @@ from services.enterprise.base import EnterpriseRequest logger = logging.getLogger(__name__) +DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 + class WebAppSettings(BaseModel): access_mode: str = Field( @@ -97,7 +99,13 @@ class EnterpriseService: # Ensure we are sending a UUID-shaped string (enterprise side validates too). uuid.UUID(account_id) - data = EnterpriseRequest.send_request("POST", "/default-workspace/members", json={"account_id": account_id}) + data = EnterpriseRequest.send_request( + "POST", + "/default-workspace/members", + json={"account_id": account_id}, + timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS, + raise_for_status=True, + ) if not isinstance(data, dict): raise ValueError("Invalid response format from enterprise default workspace API") return DefaultWorkspaceJoinResult.model_validate(data) diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py index b4201aa061..2f7905fb06 100644 --- a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -36,6 +36,8 @@ class TestJoinDefaultWorkspace: "POST", "/default-workspace/members", json={"account_id": account_id}, + timeout=1.0, + raise_for_status=True, ) def test_join_default_workspace_invalid_response_format_raises(self):