mirror of
https://github.com/langgenius/dify.git
synced 2026-06-01 06:28:14 +08:00
refactor: convert isinstance chains to match/case (part 4) (#36274)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
@ -203,12 +203,13 @@ def extract_answer_from_response(app: App, response: Any) -> str:
|
||||
"""Extract answer from app generate response"""
|
||||
answer = ""
|
||||
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
answer = process_streaming_response(response)
|
||||
elif isinstance(response, Mapping):
|
||||
answer = process_mapping_response(app, response)
|
||||
else:
|
||||
logger.warning("Unexpected response type: %s", type(response))
|
||||
match response:
|
||||
case RateLimitGenerator():
|
||||
answer = process_streaming_response(response)
|
||||
case Mapping():
|
||||
answer = process_mapping_response(app, response)
|
||||
case _:
|
||||
logger.warning("Unexpected response type: %s", type(response))
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
@ -240,30 +240,31 @@ class BaseSession[
|
||||
self.check_receiver_status()
|
||||
continue
|
||||
|
||||
if response_or_error is None:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(
|
||||
code=500,
|
||||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, HTTPStatusError):
|
||||
# HTTPStatusError from streamable_client with preserved response object
|
||||
if response_or_error.response.status_code == 401:
|
||||
raise MCPAuthError(response=response_or_error.response)
|
||||
else:
|
||||
match response_or_error:
|
||||
case None:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
|
||||
ErrorData(
|
||||
code=500,
|
||||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, JSONRPCError):
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(message=response_or_error.error.message)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
case HTTPStatusError():
|
||||
# HTTPStatusError from streamable_client with preserved response object
|
||||
if response_or_error.response.status_code == 401:
|
||||
raise MCPAuthError(response=response_or_error.response)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
|
||||
)
|
||||
case JSONRPCError():
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(message=response_or_error.error.message)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
case _:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
@ -316,65 +317,79 @@ class BaseSession[
|
||||
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, HTTPStatusError):
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
# For 401 errors, pass the HTTPStatusError directly to preserve response object
|
||||
if message.response.status_code == 401:
|
||||
response_queue.put(message)
|
||||
else:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
match message:
|
||||
case HTTPStatusError():
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
# For 401 errors, pass the HTTPStatusError directly to preserve response object
|
||||
if message.response.status_code == 401:
|
||||
response_queue.put(message)
|
||||
else:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
elif isinstance(message, Exception):
|
||||
self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
self._received_request(responder)
|
||||
|
||||
if not responder.completed:
|
||||
self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification) # type: ignore[arg-type]
|
||||
self._handle_incoming(notification) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
||||
else: # Response or error
|
||||
response_queue = self._response_streams.get(message.message.root.id)
|
||||
if response_queue is not None:
|
||||
response_queue.put(message.message.root)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
self._handle_incoming(
|
||||
RuntimeError(f"Received response with an unknown request ID: {message}")
|
||||
)
|
||||
case Exception():
|
||||
self._handle_incoming(message)
|
||||
case SessionMessage(message=JSONRPCMessage(root=JSONRPCRequest())):
|
||||
request_root = message.message.root
|
||||
if not isinstance(request_root, JSONRPCRequest):
|
||||
continue
|
||||
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
request_root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
||||
request_id=request_root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
self._received_request(responder)
|
||||
|
||||
if not responder.completed:
|
||||
self._handle_incoming(responder)
|
||||
|
||||
case SessionMessage(message=JSONRPCMessage(root=JSONRPCNotification())):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification) # type: ignore[arg-type]
|
||||
self._handle_incoming(notification) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logger.warning(
|
||||
"Failed to validate notification: %s. Message was: %s", e, message.message.root
|
||||
)
|
||||
case _: # Response or error
|
||||
response_root = message.message.root
|
||||
if not isinstance(response_root, (JSONRPCResponse, JSONRPCError)):
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
continue
|
||||
|
||||
response_queue = self._response_streams.get(response_root.id)
|
||||
if response_queue is not None:
|
||||
response_queue.put(response_root)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception:
|
||||
|
||||
@ -1554,12 +1554,13 @@ class DatasetRetrieval:
|
||||
case "≥" | ">=":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||
case "in" | "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
match value:
|
||||
case str():
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
case list() | tuple():
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
case _:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
# `field in []` is False, `field not in []` is True
|
||||
|
||||
@ -543,13 +543,16 @@ class Workflow(Base): # bug
|
||||
def decrypt_func(
|
||||
var: VariableBase,
|
||||
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
|
||||
return var
|
||||
else:
|
||||
# Other variable types are not supported for environment variables
|
||||
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
|
||||
match var:
|
||||
case SecretVariable():
|
||||
return var.model_copy(
|
||||
update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}
|
||||
)
|
||||
case StringVariable() | IntegerVariable() | FloatVariable():
|
||||
return var
|
||||
case _:
|
||||
# Other variable types are not supported for environment variables
|
||||
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
|
||||
|
||||
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
|
||||
decrypt_func(var) for var in results
|
||||
@ -1638,31 +1641,32 @@ class WorkflowDraftVariable(Base):
|
||||
# rather than their serialized forms.
|
||||
# However, multiple components in the codebase depend on
|
||||
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
|
||||
if isinstance(value, dict):
|
||||
if not maybe_file_object(value):
|
||||
return cast(Any, value)
|
||||
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
|
||||
return build_file_from_stored_mapping(
|
||||
file_mapping=cast(dict[str, Any], value),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
elif isinstance(value, list) and value:
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
if not maybe_file_object(first):
|
||||
return cast(Any, value)
|
||||
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
file_list.append(
|
||||
build_file_from_stored_mapping(
|
||||
file_mapping=cast(dict[str, Any], item),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
match value:
|
||||
case dict():
|
||||
if not maybe_file_object(value):
|
||||
return cast(Any, value)
|
||||
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
|
||||
return build_file_from_stored_mapping(
|
||||
file_mapping=cast(dict[str, Any], value),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
return cast(Any, file_list)
|
||||
else:
|
||||
return cast(Any, value)
|
||||
case list() if value:
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
if not maybe_file_object(first):
|
||||
return cast(Any, value)
|
||||
tenant_id = _resolve_workflow_app_tenant_id(self.app_id)
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
file_list.append(
|
||||
build_file_from_stored_mapping(
|
||||
file_mapping=cast(dict[str, Any], item),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
return cast(Any, file_list)
|
||||
case _:
|
||||
return cast(Any, value)
|
||||
|
||||
def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment:
|
||||
# Persisted draft variable rows may contain historical file payloads.
|
||||
@ -1671,13 +1675,14 @@ class WorkflowDraftVariable(Base):
|
||||
# serialized JSON blob.
|
||||
match segment_type:
|
||||
case SegmentType.FILE:
|
||||
if isinstance(value, File):
|
||||
return build_segment_with_type(segment_type, value)
|
||||
elif isinstance(value, dict):
|
||||
file = self._rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
else:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
match value:
|
||||
case File():
|
||||
return build_segment_with_type(segment_type, value)
|
||||
case dict():
|
||||
file = self._rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
case _:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(value, list):
|
||||
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
|
||||
@ -1692,25 +1697,26 @@ class WorkflowDraftVariable(Base):
|
||||
# structural reconstruction. Persisted draft-variable payloads should go
|
||||
# through `build_segment_from_serialized_value()` so file metadata is
|
||||
# rebuilt from canonical storage records.
|
||||
if isinstance(value, dict):
|
||||
if not maybe_file_object(value):
|
||||
return cast(Any, value)
|
||||
normalized_file = dict(value)
|
||||
normalized_file.pop("tenant_id", None)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
|
||||
elif isinstance(value, list) and value:
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
if not maybe_file_object(first):
|
||||
return cast(Any, value)
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
normalized_file = dict(cast(dict[str, Any], item))
|
||||
match value:
|
||||
case dict():
|
||||
if not maybe_file_object(value):
|
||||
return cast(Any, value)
|
||||
normalized_file = dict(value)
|
||||
normalized_file.pop("tenant_id", None)
|
||||
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
|
||||
return cast(Any, file_list)
|
||||
else:
|
||||
return cast(Any, value)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
|
||||
case list() if value:
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
if not maybe_file_object(first):
|
||||
return cast(Any, value)
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
normalized_file = dict(cast(dict[str, Any], item))
|
||||
normalized_file.pop("tenant_id", None)
|
||||
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
|
||||
return cast(Any, file_list)
|
||||
case _:
|
||||
return cast(Any, value)
|
||||
|
||||
@classmethod
|
||||
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:
|
||||
@ -1719,13 +1725,14 @@ class WorkflowDraftVariable(Base):
|
||||
# their serialized dictionary or list representations, respectively.
|
||||
match segment_type:
|
||||
case SegmentType.FILE:
|
||||
if isinstance(value, File):
|
||||
return build_segment_with_type(segment_type, value)
|
||||
elif isinstance(value, dict):
|
||||
file = cls.rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
else:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
match value:
|
||||
case File():
|
||||
return build_segment_with_type(segment_type, value)
|
||||
case dict():
|
||||
file = cls.rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
case _:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(value, list):
|
||||
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
|
||||
@ -2099,17 +2106,18 @@ class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase):
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason":
|
||||
if isinstance(pause_reason, HumanInputRequired):
|
||||
return cls(
|
||||
pause_id=pause_id,
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id=pause_reason.form_id,
|
||||
node_id=pause_reason.node_id,
|
||||
)
|
||||
elif isinstance(pause_reason, SchedulingPause):
|
||||
return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message)
|
||||
else:
|
||||
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
|
||||
match pause_reason:
|
||||
case HumanInputRequired():
|
||||
return cls(
|
||||
pause_id=pause_id,
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id=pause_reason.form_id,
|
||||
node_id=pause_reason.node_id,
|
||||
)
|
||||
case SchedulingPause():
|
||||
return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message)
|
||||
case _:
|
||||
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
|
||||
|
||||
def to_entity(self) -> PauseReason:
|
||||
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
|
||||
@ -52,20 +52,21 @@ class _EndUser(BaseModel):
|
||||
|
||||
|
||||
def _get_user_type_descriminator(value: Any):
|
||||
if isinstance(value, (_Account, _EndUser)):
|
||||
return value.TYPE
|
||||
elif isinstance(value, dict):
|
||||
user_type_str = value.get("TYPE")
|
||||
if user_type_str is None:
|
||||
match value:
|
||||
case _Account() | _EndUser():
|
||||
return value.TYPE
|
||||
case dict():
|
||||
user_type_str = value.get("TYPE")
|
||||
if user_type_str is None:
|
||||
return None
|
||||
try:
|
||||
user_type = _UserType(user_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return user_type
|
||||
case _:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
try:
|
||||
user_type = _UserType(user_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return user_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
type User = Annotated[
|
||||
@ -221,17 +222,17 @@ class _AppRunner:
|
||||
def _resolve_user(self) -> Account | EndUser:
|
||||
user_params = self._exec_params.user
|
||||
|
||||
if isinstance(user_params, _EndUser):
|
||||
with self._session() as session:
|
||||
return session.get(EndUser, user_params.end_user_id)
|
||||
elif not isinstance(user_params, _Account):
|
||||
raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}")
|
||||
|
||||
with self._session() as session:
|
||||
user: Account = session.get(Account, user_params.user_id)
|
||||
user.set_tenant_id(self._exec_params.tenant_id)
|
||||
|
||||
return user
|
||||
match user_params:
|
||||
case _EndUser():
|
||||
with self._session() as session:
|
||||
return session.get(EndUser, user_params.end_user_id)
|
||||
case _Account():
|
||||
with self._session() as session:
|
||||
user: Account = session.get(Account, user_params.user_id)
|
||||
user.set_tenant_id(self._exec_params.tenant_id)
|
||||
return user
|
||||
case _:
|
||||
raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}")
|
||||
|
||||
|
||||
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
|
||||
|
||||
Reference in New Issue
Block a user