* v4.3 update. * Update the cute_dsl_api changelog's doc link * Update version to 4.3.0 * Update the example link * Update doc to encourage user to install DSL from requirements.txt --------- Co-authored-by: Larry Wu <larwu@nvidia.com>
458 lines
17 KiB
Python
458 lines
17 KiB
Python
import numpy as np
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
from cutlass.cute.runtime import from_dlpack
|
|
import torch
|
|
|
|
|
|
@cute.jit
def print_tensor_dlpack(src: cute.Tensor):
    """Print ``src`` twice: with the built-in ``print`` and with ``cute.print_tensor``.

    NOTE(review): presumably the built-in ``print`` shows the tensor's
    type/layout description when the function is traced, while
    ``cute.print_tensor`` emits the element values -- confirm against the
    CuTe DSL documentation.
    """
    print(src)
    cute.print_tensor(src)
|
|
|
|
|
|
# Sparse emulation
|
|
class SparseEmulation:
    """Emulate a sparse matrix multiply-accumulate with one thread per output row.

    The kernel reads a compressed ``a`` operand (two kept values per group of
    four logical K elements), a dense ``b`` operand indexed as ``b[col, k]``,
    packed metadata ``e`` (one word per 32 logical K elements, eight 4-bit
    codes per word) and accumulates into ``d[row, col]``.

    ``d`` is accumulated into with ``+=`` and is never cleared here, so the
    caller is responsible for its initial contents.

    ``L`` is stored but not used by any code in this class.
    """

    def __init__(self, M: int, N: int, K: int, L: int):
        self.M = M
        self.N = N
        self.K = K
        self.L = L

    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """Launch the emulation kernel with one thread per row of ``d``."""
        num_threads = 128
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.kernel(a, b, d, e).launch(grid=grid, block=block)
        return

    @cute.kernel
    def kernel(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """CUDA kernel to emulate sparse tensor core"""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        bdimx, _, _ = cute.arch.block_dim()

        # Global row handled by this thread.  The stride between blocks is the
        # block size, not M: striding by M pushed every block after the first
        # past the `row_idx < self.M` guard, leaving rows >= blockDim untouched.
        row_idx = tidx + bidx * bdimx
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # each thread process 1 row
            for col in range(self.N):
                # each meta_idx stands for 32 elements
                for e_idx in range(meta_idx):
                    meta_val = e[(row_idx, e_idx)]
                    for k in range(8):
                        # each k stands for 4 elements
                        meta_row = (meta_val >> (k * 4)) & 0xF
                        # low / high 2 bits: positions (within the group of 4)
                        # of the two values kept in the compressed operand
                        idx0 = meta_row & 0x3
                        idx1 = (meta_row >> 2) & 0x3
                        # calculate the idx in b tensor which has value in A tensor
                        km = e_idx * 16 + k * 2
                        km_1 = km + 1
                        kn = e_idx * 32 + k * 4 + idx0
                        kn_1 = e_idx * 32 + k * 4 + idx1
                        d[row_idx, col] += a[row_idx, km] * b[col, kn]
                        d[row_idx, col] += a[row_idx, km_1] * b[col, kn_1]
        return
|
|
|
|
|
|
# Compressor
|
|
# compress a sparse tensor to a dense tensor && generate metadata
|
|
class Compressor:
    """Compress a 4:2 sparse matrix and produce its packed metadata.

    For every group of four consecutive K elements, two values are copied to
    the compressed output and a 4-bit code is recorded: the low two bits hold
    the position of the first copied value, the high two bits the position of
    the second.  Eight codes are packed into each metadata word, so one word
    covers 32 input elements.

    ``L`` is stored but not used by any code in this class.
    """

    # CPU path lookup: positions of the non-zero entries in a 4-element chunk
    # -> metadata code.  Chunks with fewer than two non-zeros are mapped to a
    # code that still covers the non-zero position(s):
    #   [0,0,0,0] -> 0xE   [1,0,0,0] -> 0xC   [0,1,0,0] -> 0xD
    #   [0,0,1,0] -> 0xE   [0,0,0,1] -> 0xE
    _PATTERN_TO_META = {
        (0, 1): 0x4,
        (0, 2): 0x8,
        (0, 3): 0xC,
        (0,): 0xC,
        (1, 2): 0x9,
        (1, 3): 0xD,
        (1,): 0xD,
        (2, 3): 0xE,
        (2,): 0xE,
        (3,): 0xE,
        (): 0xE,
    }

    def __init__(self, M: int, K: int, L: int):
        self.M = M
        self.K = K
        self.L = L
        # metadata code -> the two chunk positions copied to the compressed output
        self.pos_map = {
            0x4: [0, 1],
            0x8: [0, 2],
            0xC: [0, 3],
            0x9: [1, 2],
            0xD: [1, 3],
            0xE: [2, 3],
        }

    # NOTE(review): the name has a single leading underscore (`_init__`), which
    # looks like a typo of `__init__`; it is kept unchanged because renaming it
    # would alter the class interface.
    @cute.jit
    def _init__(self, a: cute.Tensor):
        self.__init__(a.shape[0], a.shape[1], a.shape[2])

    def compress(self, a, a_compressed, meta, run_on_cpu: bool):
        """Compress ``a`` into ``a_compressed`` and fill ``meta``.

        :param a: input tensor; must live on the device selected by ``run_on_cpu``
        :param a_compressed: output tensor receiving the kept values
        :param meta: output tensor receiving the packed metadata words
        :param run_on_cpu: ``True`` for the pure-PyTorch path, ``False`` for the CUDA kernel
        :raises ValueError: if ``a`` is not on the requested device
        """
        if run_on_cpu:
            if a.device.type != "cpu":
                raise ValueError("a must be on cpu")
            return self.__compress_on_cpu(a, a_compressed, meta)

        if a.device.type != "cuda":
            raise ValueError("a must be on cuda")
        return self.__compress_on_cuda(a, a_compressed, meta)

    def __compress_on_cpu(self, a, a_compressed, meta):
        """
        compress the tensor on cpu

        Each 4-element chunk of a row is converted to a 4-bit metadata value
        that represents which 2 elements are kept:
          0x4: [1,1,0,0] - first two elements are non-zero
          0x8: [1,0,1,0] - first and third elements are non-zero
          0xC: [1,0,0,1] - first and fourth elements are non-zero
          0x9: [0,1,1,0] - second and third elements are non-zero
          0xD: [0,1,0,1] - second and fourth elements are non-zero
          0xE: [0,0,1,1] - third and fourth elements are non-zero
        special case:
          [0,0,0,0] == [0,0,1,1]
          [1,0,0,0] == [1,0,0,1]
          [0,1,0,0] == [0,1,0,1]
          [0,0,1,0] == [0,0,1,1]
          [0,0,0,1] == [0,0,1,1]

        :raises ValueError: if a chunk has more than two non-zero elements
        """
        M, K = a.shape
        assert a_compressed.shape == (
            M,
            K // 2,
        ), f"Expected a_compressed shape {(M, K // 2)}, got {a_compressed.shape}"
        assert meta.shape == (
            M,
            K // 4 // 8,
        ), f"Expected meta shape {(M, K // 4 // 8)}, got {meta.shape}"

        for m in range(M):
            # index of the next metadata word that still has to be cleared
            k_meta = 0
            for k in range(0, K, 4):
                chunk = a[m, k : k + 4]

                non_zero_indices = torch.nonzero(chunk).squeeze()
                # One table lookup instead of building a comparison tensor per
                # candidate pattern; reshape(-1) also flattens the 0-dim result
                # that squeeze() yields for a single non-zero element.
                pattern = tuple(non_zero_indices.reshape(-1).tolist())
                meta_val = self._PATTERN_TO_META.get(pattern)
                if meta_val is None:
                    raise ValueError(f"Invalid non-zero pattern: {non_zero_indices}")

                meta_idx = k // 4 // 8
                meta_bit_pos = (k // 4) % 8
                if k_meta == meta_idx:
                    # first chunk of a new metadata word: clear it before OR-ing
                    k_meta = meta_idx + 1
                    meta[m, meta_idx] = 0
                meta[m, meta_idx] |= meta_val << (meta_bit_pos * 4)

                compressed_idx = k // 2
                index = self.pos_map[meta_val]
                a_compressed[m, compressed_idx] = chunk[index[0]]
                a_compressed[m, compressed_idx + 1] = chunk[index[1]]

    def __compress_on_cuda(self, a, a_compressed, meta):
        """
        compress the tensor on cuda
        """
        a_tensor = from_dlpack(a)
        a_compressed_tensor = from_dlpack(a_compressed)
        meta_tensor = from_dlpack(meta)
        self.compress_on_cuda_impl(a_tensor, a_compressed_tensor, meta_tensor)
        return

    @cute.jit
    def compress_on_cuda_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """Launch the compression kernel with one thread per row of ``a``."""
        num_threads = 128
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.compressor_impl(a, a_compressed, meta).launch(grid=grid, block=block)

    @cute.kernel
    def compressor_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """CUDA kernel to compress the tensor"""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        bdimx, _, _ = cute.arch.block_dim()

        # each thread process 1 row
        # The stride between blocks is the block size, not M: striding by M
        # pushed every block after the first past the `row_idx < self.M` guard,
        # so rows >= blockDim were never compressed.
        row_idx = tidx + bidx * bdimx
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # each meta_idx stands for 32 elements
            for i in range(meta_idx):
                meta[row_idx, i] = 0
                # each k stands for 4 elements
                for j in range(8):
                    val = a[row_idx, i * 32 + j * 4]
                    val_1 = a[row_idx, i * 32 + j * 4 + 1]
                    val_2 = a[row_idx, i * 32 + j * 4 + 2]
                    val_3 = a[row_idx, i * 32 + j * 4 + 3]
                    # 0/1 flags: is the element at each position non-zero?
                    value_idx = 0
                    value_idx_1 = 0
                    value_idx_2 = 0
                    value_idx_3 = 0
                    # positions (within the group) copied to the compressed row
                    pos0 = 0
                    pos1 = 0
                    if val != 0:
                        value_idx = 1
                    if val_1 != 0:
                        value_idx_1 = 1
                    if val_2 != 0:
                        value_idx_2 = 1
                    if val_3 != 0:
                        value_idx_3 = 1
                    pos = [value_idx, value_idx_1, value_idx_2, value_idx_3]
                    # tmp is the 4-bit metadata code; it stays 0 (and
                    # pos0 = pos1 = 0) when no pattern below matches.
                    tmp = 0
                    if pos == [0, 0, 0, 0]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [1, 0, 0, 0]:
                        tmp = 0xC
                        pos0 = 0
                        pos1 = 3
                    elif pos == [0, 1, 0, 0]:
                        tmp = 0xD
                        pos0 = 1
                        pos1 = 3
                    elif pos == [0, 0, 1, 0]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [0, 0, 0, 1]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [1, 1, 0, 0]:
                        tmp = 0x4
                        pos0 = 0
                        pos1 = 1
                    elif pos == [1, 0, 1, 0]:
                        tmp = 0x8
                        pos0 = 0
                        pos1 = 2
                    elif pos == [1, 0, 0, 1]:
                        tmp = 0xC
                        pos0 = 0
                        pos1 = 3
                    elif pos == [0, 1, 1, 0]:
                        tmp = 0x9
                        pos0 = 1
                        pos1 = 2
                    elif pos == [0, 1, 0, 1]:
                        tmp = 0xD
                        pos0 = 1
                        pos1 = 3
                    elif pos == [0, 0, 1, 1]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    meta[row_idx, i] |= tmp << (j * 4)

                    a_compressed[row_idx, i * 16 + j * 2] = a[
                        row_idx, i * 32 + j * 4 + pos0
                    ]
                    a_compressed[row_idx, i * 16 + j * 2 + 1] = a[
                        row_idx, i * 32 + j * 4 + pos1
                    ]

        return
|
|
|
|
|
|
# SparseUtils is used to generate sparse tensor
|
|
# format torch.Tensor
|
|
class SparseUtils:
    """Generate 4:2 sparse ``torch.Tensor`` data (two zeros in every group of four).

    :param M: number of rows
    :param K: number of columns of the (uncompressed) tensor
    :param L: stored but not used by any code in this class
    :param dtype: cutlass data type; Float16, Float32 and Int8 are supported
    """

    def __init__(self, M: int, K: int, L: int, dtype):
        self.M = M
        self.K = K
        self.L = L
        self.dtype = dtype
        # one 4-bit code per 4-element chunk, shape (M, K // 4), dtype uint8
        self.meta_data = self._generate_meta_data_4_2()
        self._use_specific_meta_data = False

    def _get_type(self):
        """Map ``self.dtype`` (a cutlass data type) to the matching torch dtype.

        :raises ValueError: if the dtype is not Float16, Float32 or Int8
        """
        if self.dtype == cutlass.Float16:
            return torch.float16
        elif self.dtype == cutlass.Float32:
            return torch.float32
        elif self.dtype == cutlass.Int8:
            return torch.int8
        else:
            raise ValueError(f"Unsupported dtype: {self.dtype}")

    def _generate_meta_data_4_2(self):
        """Draw a random metadata code for every 4-element chunk.

        :return: ``torch.Tensor`` of shape ``(M, K // 4)`` and dtype uint8
        """
        # metadata for 4:2 sparse is one of (4, 8, 9, c, d, e); the low two
        # bits are the first non-zero position, the high two bits the second:
        #   [1,1,0,0] non-zero pos 00,01 -> 0100 = 4
        #   [1,0,1,0] non-zero pos 00,10 -> 1000 = 8
        #   [1,0,0,1] non-zero pos 00,11 -> 1100 = c
        #   [0,1,1,0] non-zero pos 01,10 -> 1001 = 9
        #   [0,1,0,1] non-zero pos 01,11 -> 1101 = d
        #   [0,0,1,1] non-zero pos 10,11 -> 1110 = e
        meta_value = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE]
        # 4:2 sparse, so each chunk is 4 elements, map to 4 bits
        K_NumChunk = self.K // 4
        codes = np.random.choice(meta_value, size=(self.M, K_NumChunk), replace=True)
        return torch.from_numpy(codes.astype(np.uint8))

    def _pack_meta_data(self):
        """Pack eight consecutive 4-bit codes of ``self.meta_data`` into one word.

        Code ``k`` of a group of eight lands in bits ``[4k, 4k + 4)``.  Chunks
        beyond the last complete group of eight are not packed.

        :return: ``torch.Tensor`` of shape ``(M, K // 4 // 8)`` and dtype uint32
        """
        K_NumChunk = self.K // 4
        words_per_row = K_NumChunk // 8
        packed = []
        for i in range(self.M):
            for j in range(words_per_row):
                word = 0
                for k in range(8):
                    nibble = int(self.meta_data[i, j * 8 + k] & 0xF)
                    word |= nibble << (k * 4)
                packed.append(word)
        return torch.from_numpy(
            np.array(packed).astype(np.uint32).reshape(self.M, words_per_row)
        )

    def use_specific_meta_data(self, meta_data: torch.Tensor = None):
        """Switch the CPU generator to metadata-driven masking.

        :param meta_data: optional replacement for ``self.meta_data``; when
            ``None`` the metadata already stored on the instance is used
        """
        if meta_data is not None:
            self.meta_data = meta_data
        self._use_specific_meta_data = True

    def generate_sparse_4_2_tensor_with_tensor(self, a, run_on_cpu):
        """Impose the 4:2 sparsity pattern on an existing tensor.

        :param a: torch.Tensor on the device selected by ``run_on_cpu``
        :param run_on_cpu: ``True`` for the PyTorch path, ``False`` for the CUDA kernel
        :return: the sparsified tensor (the CUDA path modifies and returns ``a`` itself)
        :raises ValueError: if ``a`` is not on the requested device
        """
        if run_on_cpu:
            if a.device.type != "cpu":
                raise ValueError("a must be on cpu")
            return self.__generate_sparse_tensor_cpu(a)

        if a.device.type != "cuda":
            raise ValueError("a must be on cuda")
        a_tensor = from_dlpack(a)
        packed_meta_data = self._pack_meta_data()
        meta_tensor = from_dlpack(packed_meta_data.cuda())
        self.__generate_sparse_tensor_cuda(a_tensor, meta_tensor)
        return a

    def generate_4_2_sparse_tensor(self, run_on_cpu):
        """Create a random ``(M, K)`` tensor and sparsify it.

        :param run_on_cpu: ``True`` to generate on the CPU, ``False`` on CUDA
        :return: torch.Tensor with the 4:2 pattern applied
        """
        dtype = self._get_type()
        a = torch.empty(self.M, self.K).random_(-5, 5).to(dtype)
        if not run_on_cpu:
            a = a.cuda()
        return self.generate_sparse_4_2_tensor_with_tensor(a, run_on_cpu)

    def __generate_sparse_tensor_cpu(self, a):
        """Sparsify ``a`` on the CPU.

        Without specific metadata, two random positions of every 4-element
        chunk are zeroed in place.  With specific metadata, a 0/1 mask is built
        from ``self.meta_data`` and a new masked tensor is returned.

        :param a: torch.Tensor of shape ``(M, K)``
        :return: torch.Tensor of shape ``(M, K)``
        :raises ValueError: if ``self.meta_data`` holds an unknown code
        """
        if not self._use_specific_meta_data:
            for m in range(self.M):
                for k in range(0, self.K, 4):
                    # random choose 2 zero positions
                    zero_indices = torch.randperm(4)[:2]
                    a[m, k + zero_indices[0]] = 0
                    a[m, k + zero_indices[1]] = 0
            return a

        # use specific meta data
        keep_masks = {
            0x4: [1, 1, 0, 0],
            0x8: [1, 0, 1, 0],
            0xC: [1, 0, 0, 1],
            0x9: [0, 1, 1, 0],
            0xD: [0, 1, 0, 1],
            0xE: [0, 0, 1, 1],
        }
        tensor_mask = []
        for i in range(self.M):
            for j in range(self.K // 4):
                meta_val = self.meta_data[i, j].item()
                keep = keep_masks.get(meta_val)
                if keep is None:
                    # Previously an unknown code contributed nothing to the
                    # mask and surfaced later as a shape mismatch.
                    raise ValueError(
                        f"Invalid 4:2 metadata value {meta_val} at ({i}, {j})"
                    )
                tensor_mask.extend(keep)

        flat = torch.reshape(a, (-1,))
        # Build the mask in a's dtype so the product keeps that dtype; a default
        # int64 mask would promote an int8 tensor to int64.
        mask = torch.tensor(tensor_mask, dtype=flat.dtype)
        return torch.reshape(flat * mask, (self.M, self.K))

    @cute.jit
    def __generate_sparse_tensor_cuda(self, a: cute.Tensor, meta: cute.Tensor):
        """Generate a sparse tensor from a dense tensor using metadata"""
        assert a.shape[0] == self.M and a.shape[1] == self.K
        assert meta.shape[0] == self.M and meta.shape[1] == self.K // 4 // 8
        num_threads = 128
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.kernel(a, meta).launch(grid=grid, block=block)

    @cute.kernel
    def kernel(self, a: cute.Tensor, meta: cute.Tensor):
        """Apply sparsity mask to input tensor using metadata"""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        bdimx, _, _ = cute.arch.block_dim()

        # each thread process 1 row
        # The stride between blocks is the block size, not M: striding by M
        # pushed every block after the first past the `row_idx < self.M` guard,
        # so rows >= blockDim were never masked.
        row_idx = tidx + bidx * bdimx
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # iterate over each chunk(32 elements)
            for i in range(meta_idx):
                meta_val = meta[(row_idx, i)]
                # iterate over each sparse pattern(4 elements)
                for j in range(8):
                    meta_row = (meta_val >> (j * 4)) & 0xF
                    # positions of the two values that are kept
                    idx0 = meta_row & 0x3
                    idx1 = (meta_row >> 2) & 0x3
                    r_id0 = 0
                    r_id1 = 0
                    # r_id is the idx that value is 0
                    if idx0 >= 2 and idx1 >= 2:
                        # both kept values are in the upper pair -> clear 0 and 1
                        r_id0 = 0
                        r_id1 = 1
                    elif idx0 <= 1 and idx1 <= 1:
                        # both kept values are in the lower pair -> clear 2 and 3
                        r_id0 = 2
                        r_id1 = 3
                    else:
                        # one kept value per pair -> clear its neighbour in each pair
                        r_id0 = idx0 ^ 0b1
                        r_id1 = idx1 ^ 0b1
                    row_id0 = r_id0 + i * 32 + j * 4
                    row_id1 = r_id1 + i * 32 + j * 4
                    a[row_idx, row_id0] = self.dtype(0.0)
                    a[row_idx, row_id1] = self.dtype(0.0)
        return
|