refactor: convert isinstance chains to match/case (part 4) (#36274)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Evan
2026-05-31 22:44:17 +08:00
committed by GitHub
parent f241ae25be
commit d8571ce965
5 changed files with 215 additions and 189 deletions

View File

@ -203,12 +203,13 @@ def extract_answer_from_response(app: App, response: Any) -> str:
"""Extract answer from app generate response"""
answer = ""
if isinstance(response, RateLimitGenerator):
answer = process_streaming_response(response)
elif isinstance(response, Mapping):
answer = process_mapping_response(app, response)
else:
logger.warning("Unexpected response type: %s", type(response))
match response:
case RateLimitGenerator():
answer = process_streaming_response(response)
case Mapping():
answer = process_mapping_response(app, response)
case _:
logger.warning("Unexpected response type: %s", type(response))
return answer

View File

@ -240,30 +240,31 @@ class BaseSession[
self.check_receiver_status()
continue
if response_or_error is None:
raise MCPConnectionError(
ErrorData(
code=500,
message="No response received",
)
)
elif isinstance(response_or_error, HTTPStatusError):
# HTTPStatusError from streamable_client with preserved response object
if response_or_error.response.status_code == 401:
raise MCPAuthError(response=response_or_error.response)
else:
match response_or_error:
case None:
raise MCPConnectionError(
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
ErrorData(
code=500,
message="No response received",
)
)
elif isinstance(response_or_error, JSONRPCError):
if response_or_error.error.code == 401:
raise MCPAuthError(message=response_or_error.error.message)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
else:
return result_type.model_validate(response_or_error.result)
case HTTPStatusError():
# HTTPStatusError from streamable_client with preserved response object
if response_or_error.response.status_code == 401:
raise MCPAuthError(response=response_or_error.response)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
)
case JSONRPCError():
if response_or_error.error.code == 401:
raise MCPAuthError(message=response_or_error.error.message)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
case _:
return result_type.model_validate(response_or_error.result)
finally:
self._response_streams.pop(request_id, None)
@ -316,65 +317,79 @@ class BaseSession[
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, HTTPStatusError):
response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None:
# For 401 errors, pass the HTTPStatusError directly to preserve response object
if message.response.status_code == 401:
response_queue.put(message)
else:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
match message:
case HTTPStatusError():
response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None:
# For 401 errors, pass the HTTPStatusError directly to preserve response object
if message.response.status_code == 401:
response_queue.put(message)
else:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
)
)
)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
elif isinstance(message, Exception):
self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
self._received_request(responder)
if not responder.completed:
self._handle_incoming(responder)
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) # type: ignore[arg-type]
except Exception as e:
# For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
else: # Response or error
response_queue = self._response_streams.get(message.message.root.id)
if response_queue is not None:
response_queue.put(message.message.root)
else:
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
self._handle_incoming(
RuntimeError(f"Received response with an unknown request ID: {message}")
)
case Exception():
self._handle_incoming(message)
case SessionMessage(message=JSONRPCMessage(root=JSONRPCRequest())):
request_root = message.message.root
if not isinstance(request_root, JSONRPCRequest):
continue
validated_request = self._receive_request_type.model_validate(
request_root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=request_root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
self._received_request(responder)
if not responder.completed:
self._handle_incoming(responder)
case SessionMessage(message=JSONRPCMessage(root=JSONRPCNotification())):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) # type: ignore[arg-type]
except Exception as e:
# For other validation errors, log and continue
logger.warning(
"Failed to validate notification: %s. Message was: %s", e, message.message.root
)
case _: # Response or error
response_root = message.message.root
if not isinstance(response_root, (JSONRPCResponse, JSONRPCError)):
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
continue
response_queue = self._response_streams.get(response_root.id)
if response_queue is not None:
response_queue.put(response_root)
else:
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception:

View File

@ -1554,12 +1554,13 @@ class DatasetRetrieval:
case "" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case "in" | "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
match value:
case str():
value_list = [v.strip() for v in value.split(",") if v.strip()]
case list() | tuple():
value_list = [str(v) for v in value if v is not None]
case _:
value_list = [str(value)] if value is not None else []
if not value_list:
# `field in []` is False, `field not in []` is True

View File

@ -543,13 +543,16 @@ class Workflow(Base): # bug
def decrypt_func(
var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
else:
# Other variable types are not supported for environment variables
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
match var:
case SecretVariable():
return var.model_copy(
update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}
)
case StringVariable() | IntegerVariable() | FloatVariable():
return var
case _:
# Other variable types are not supported for environment variables
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
decrypt_func(var) for var in results
@ -1638,31 +1641,32 @@ class WorkflowDraftVariable(Base):
# rather than their serialized forms.
# However, multiple components in the codebase depend on
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
if isinstance(value, dict):
if not maybe_file_object(value):
return cast(Any, value)
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
return build_file_from_stored_mapping(
file_mapping=cast(dict[str, Any], value),
tenant_id=tenant_id,
)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first):
return cast(Any, value)
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
file_list: list[File] = []
for item in value_list:
file_list.append(
build_file_from_stored_mapping(
file_mapping=cast(dict[str, Any], item),
tenant_id=tenant_id,
)
match value:
case dict():
if not maybe_file_object(value):
return cast(Any, value)
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
return build_file_from_stored_mapping(
file_mapping=cast(dict[str, Any], value),
tenant_id=tenant_id,
)
return cast(Any, file_list)
else:
return cast(Any, value)
case list() if value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first):
return cast(Any, value)
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
file_list: list[File] = []
for item in value_list:
file_list.append(
build_file_from_stored_mapping(
file_mapping=cast(dict[str, Any], item),
tenant_id=tenant_id,
)
)
return cast(Any, file_list)
case _:
return cast(Any, value)
def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment:
# Persisted draft variable rows may contain historical file payloads.
@ -1671,13 +1675,14 @@ class WorkflowDraftVariable(Base):
# serialized JSON blob.
match segment_type:
case SegmentType.FILE:
if isinstance(value, File):
return build_segment_with_type(segment_type, value)
elif isinstance(value, dict):
file = self._rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
else:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
match value:
case File():
return build_segment_with_type(segment_type, value)
case dict():
file = self._rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
case _:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
case SegmentType.ARRAY_FILE:
if not isinstance(value, list):
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
@ -1692,25 +1697,26 @@ class WorkflowDraftVariable(Base):
# structural reconstruction. Persisted draft-variable payloads should go
# through `build_segment_from_serialized_value()` so file metadata is
# rebuilt from canonical storage records.
if isinstance(value, dict):
if not maybe_file_object(value):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first):
return cast(Any, value)
file_list: list[File] = []
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
match value:
case dict():
if not maybe_file_object(value):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
case list() if value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first):
return cast(Any, value)
file_list: list[File] = []
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
case _:
return cast(Any, value)
@classmethod
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:
@ -1719,13 +1725,14 @@ class WorkflowDraftVariable(Base):
# their serialized dictionary or list representations, respectively.
match segment_type:
case SegmentType.FILE:
if isinstance(value, File):
return build_segment_with_type(segment_type, value)
elif isinstance(value, dict):
file = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
else:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
match value:
case File():
return build_segment_with_type(segment_type, value)
case dict():
file = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
case _:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
case SegmentType.ARRAY_FILE:
if not isinstance(value, list):
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
@ -2099,17 +2106,18 @@ class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase):
@classmethod
def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
pause_id=pause_id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=pause_reason.form_id,
node_id=pause_reason.node_id,
)
elif isinstance(pause_reason, SchedulingPause):
return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message)
else:
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
match pause_reason:
case HumanInputRequired():
return cls(
pause_id=pause_id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=pause_reason.form_id,
node_id=pause_reason.node_id,
)
case SchedulingPause():
return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message)
case _:
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
def to_entity(self) -> PauseReason:
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:

View File

@ -52,20 +52,21 @@ class _EndUser(BaseModel):
def _get_user_type_descriminator(value: Any):
if isinstance(value, (_Account, _EndUser)):
return value.TYPE
elif isinstance(value, dict):
user_type_str = value.get("TYPE")
if user_type_str is None:
match value:
case _Account() | _EndUser():
return value.TYPE
case dict():
user_type_str = value.get("TYPE")
if user_type_str is None:
return None
try:
user_type = _UserType(user_type_str)
except ValueError:
return None
return user_type
case _:
# return None if the discriminator value isn't found
return None
try:
user_type = _UserType(user_type_str)
except ValueError:
return None
return user_type
else:
# return None if the discriminator value isn't found
return None
type User = Annotated[
@ -221,17 +222,17 @@ class _AppRunner:
def _resolve_user(self) -> Account | EndUser:
user_params = self._exec_params.user
if isinstance(user_params, _EndUser):
with self._session() as session:
return session.get(EndUser, user_params.end_user_id)
elif not isinstance(user_params, _Account):
raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}")
with self._session() as session:
user: Account = session.get(Account, user_params.user_id)
user.set_tenant_id(self._exec_params.tenant_id)
return user
match user_params:
case _EndUser():
with self._session() as session:
return session.get(EndUser, user_params.end_user_id)
case _Account():
with self._session() as session:
user: Account = session.get(Account, user_params.user_id)
user.set_tenant_id(self._exec_params.tenant_id)
return user
case _:
raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}")
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None: