diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 0dcf02113f..fbd38fc472 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -158,10 +158,8 @@ class MsgpackEncoder: self, obj: torch.Tensor ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None - # this creates a copy of the tensor if it's not already contiguous - obj = obj.contiguous() # view the tensor as a 1D array of bytes - arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() + arr = obj.flatten().view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) @@ -169,7 +167,7 @@ class MsgpackEncoder: # Otherwise encode index of backing buffer to avoid copy. data = len(self.aux_buffers) self.aux_buffers.append(arr.data) - dtype = str(obj.dtype)[6:] # remove 'torch.' prefix + dtype = str(obj.dtype).removeprefix("torch.") return dtype, obj.shape, data def _encode_nested_tensors(self, nt: NestedTensors) -> Any: @@ -245,7 +243,7 @@ class MsgpackDecoder: # zero-copy decode. We assume the ndarray will not be kept around, # as it now locks the whole received message buffer in memory. buffer = self.aux_buffers[data] if isinstance(data, int) else data - return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + return np.frombuffer(buffer, dtype=dtype).reshape(shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: dtype, shape, data = arr @@ -254,12 +252,15 @@ class MsgpackDecoder: # not complain about a readonly memoryview. buffer = self.aux_buffers[data] if isinstance(data, int) \ else bytearray(data) - # Create numpy wrapper around the bytes - arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), )) torch_dtype = getattr(torch, dtype) assert isinstance(torch_dtype, torch.dtype) + if not buffer: # torch.frombuffer doesn't like empty buffers + assert 0 in shape + return torch.empty(shape, dtype=torch_dtype) + # Create uint8 array + arr = torch.frombuffer(buffer, dtype=torch.uint8) # Convert back to proper shape & type - return torch.from_numpy(arr).view(torch_dtype).view(shape) + return arr.view(torch_dtype).view(shape) def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: decoded_items = []