From 2b739b95448367baefbfa728cb09c18ba2bed6c2 Mon Sep 17 00:00:00 2001 From: GareArc Date: Wed, 4 Mar 2026 19:53:43 -0800 Subject: [PATCH] fix: handle enterprise API errors properly to prevent KeyError crashes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When enterprise API returns 403/404, the response contains error JSON instead of expected data structure. Code was accessing fields directly causing KeyError → 500 Internal Server Error. Changes: - Add enterprise-specific error classes (EnterpriseAPIError, etc.) - Implement centralized error validation in EnterpriseRequest.send_request() - Extract error messages from API responses (message/error/detail fields) - Raise domain-specific errors based on HTTP status codes - Preserve backward compatibility with raise_for_status parameter This prevents KeyError crashes and returns proper HTTP error codes (403/404) instead of 500 errors. --- api/services/enterprise/base.py | 56 +++++++++++++++++++++++++++++++ api/services/errors/__init__.py | 2 ++ api/services/errors/enterprise.py | 45 +++++++++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 api/services/errors/enterprise.py diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 744b7992f8..86cca34cf2 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -6,6 +6,13 @@ from typing import Any import httpx from core.helper.trace_id_helper import generate_traceparent_header +from services.errors.enterprise import ( + EnterpriseAPIBadRequestError, + EnterpriseAPIError, + EnterpriseAPIForbiddenError, + EnterpriseAPINotFoundError, + EnterpriseAPIUnauthorizedError, +) logger = logging.getLogger(__name__) @@ -64,10 +71,59 @@ class BaseRequest: request_kwargs["timeout"] = timeout response = client.request(method, url, **request_kwargs) + + # Always validate HTTP status and raise domain-specific errors + if not response.is_success: + cls._handle_error_response(response) + + # Legacy support: still respect raise_for_status parameter if raise_for_status: response.raise_for_status() + return response.json() + @classmethod + def _handle_error_response(cls, response: httpx.Response) -> None: + """ + Handle non-2xx HTTP responses by raising appropriate domain errors. + + Attempts to extract error message from JSON response body, + falls back to status text if parsing fails. + """ + error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}" + + # Try to extract error message from JSON response + try: + error_data = response.json() + if isinstance(error_data, dict): + # Common error response formats: + # {"error": "...", "message": "..."} + # {"message": "..."} + # {"detail": "..."} + error_message = ( + error_data.get("message") + or error_data.get("error") + or error_data.get("detail") + or error_message + ) + except Exception: + # If JSON parsing fails, use the default message + logger.debug( + "Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True + ) + + # Raise specific error based on status code + if response.status_code == 400: + raise EnterpriseAPIBadRequestError(error_message) + elif response.status_code == 401: + raise EnterpriseAPIUnauthorizedError(error_message) + elif response.status_code == 403: + raise EnterpriseAPIForbiddenError(error_message) + elif response.status_code == 404: + raise EnterpriseAPINotFoundError(error_message) + else: + raise EnterpriseAPIError(error_message, status_code=response.status_code) + class EnterpriseRequest(BaseRequest): base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 697e691224..15f004463d 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -7,6 +7,7 @@ from . import ( conversation, dataset, document, + enterprise, file, index, message, @@ -21,6 +22,7 @@ __all__ = [ "conversation", "dataset", "document", + "enterprise", "file", "index", "message", diff --git a/api/services/errors/enterprise.py b/api/services/errors/enterprise.py new file mode 100644 index 0000000000..c9126199fd --- /dev/null +++ b/api/services/errors/enterprise.py @@ -0,0 +1,45 @@ +"""Enterprise service errors.""" + +from services.errors.base import BaseServiceError + + +class EnterpriseServiceError(BaseServiceError): + """Base exception for enterprise service errors.""" + + def __init__(self, description: str | None = None, status_code: int | None = None): + super().__init__(description) + self.status_code = status_code + + +class EnterpriseAPIError(EnterpriseServiceError): + """Generic enterprise API error (non-2xx response).""" + + pass + + +class EnterpriseAPINotFoundError(EnterpriseServiceError): + """Enterprise API returned 404 Not Found.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=404) + + +class EnterpriseAPIForbiddenError(EnterpriseServiceError): + """Enterprise API returned 403 Forbidden.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=403) + + +class EnterpriseAPIUnauthorizedError(EnterpriseServiceError): + """Enterprise API returned 401 Unauthorized.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=401) + + +class EnterpriseAPIBadRequestError(EnterpriseServiceError): + """Enterprise API returned 400 Bad Request.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=400)