[BugFix] Handle non-contiguous tensors properly when serializing (#16492)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -22,6 +22,10 @@ class MyType:
|
||||
list_of_tensors: list[torch.Tensor]
|
||||
numpy_array: np.ndarray
|
||||
unrecognized: UnrecognizedType
|
||||
small_f_contig_tensor: torch.Tensor
|
||||
large_f_contig_tensor: torch.Tensor
|
||||
small_non_contig_tensor: torch.Tensor
|
||||
large_non_contig_tensor: torch.Tensor
|
||||
|
||||
|
||||
def test_encode_decode():
|
||||
@ -40,6 +44,10 @@ def test_encode_decode():
|
||||
],
|
||||
numpy_array=np.arange(512),
|
||||
unrecognized=UnrecognizedType(33),
|
||||
small_f_contig_tensor=torch.rand(5, 4).t(),
|
||||
large_f_contig_tensor=torch.rand(1024, 4).t(),
|
||||
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
|
||||
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||
)
|
||||
|
||||
encoder = MsgpackEncoder()
|
||||
@ -47,10 +55,10 @@ def test_encode_decode():
|
||||
|
||||
encoded = encoder.encode(obj)
|
||||
|
||||
# There should be the main buffer + 2 large tensor buffers
|
||||
# + 1 large numpy array. "large" is <= 256 bytes.
|
||||
# There should be the main buffer + 4 large tensor buffers
|
||||
# + 1 large numpy array. "large" is <= 512 bytes.
|
||||
# The two small tensors are encoded inline.
|
||||
assert len(encoded) == 4
|
||||
assert len(encoded) == 6
|
||||
|
||||
decoded: MyType = decoder.decode(encoded)
|
||||
|
||||
@ -62,7 +70,7 @@ def test_encode_decode():
|
||||
|
||||
encoded2 = encoder.encode_into(obj, preallocated)
|
||||
|
||||
assert len(encoded2) == 4
|
||||
assert len(encoded2) == 6
|
||||
assert encoded2[0] is preallocated
|
||||
|
||||
decoded2: MyType = decoder.decode(encoded2)
|
||||
@ -78,3 +86,9 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
||||
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
|
||||
assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
|
||||
assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
|
||||
assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
|
||||
assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
|
||||
assert torch.equal(obj1.small_non_contig_tensor,
|
||||
obj2.small_non_contig_tensor)
|
||||
assert torch.equal(obj1.large_non_contig_tensor,
|
||||
obj2.large_non_contig_tensor)
|
||||
|
||||
Reference in New Issue
Block a user