v4.3 update. (#2709)
* v4.3 update. * Update the cute_dsl_api changelog's doc link * Update version to 4.3.0 * Update the example link * Update doc to encourage user to install DSL from requirements.txt --------- Co-authored-by: Larry Wu <larwu@nvidia.com>
This commit is contained in:
457
examples/python/CuTeDSL/utils/sparse_utils.py
Normal file
457
examples/python/CuTeDSL/utils/sparse_utils.py
Normal file
@ -0,0 +1,457 @@
|
||||
import numpy as np
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import torch
|
||||
|
||||
|
||||
@cute.jit
def print_tensor_dlpack(src: cute.Tensor):
    """Print a CuTe tensor twice: its object repr and its element values.

    :param src: a ``cute.Tensor`` (typically produced via ``from_dlpack``).
    """
    # Python repr of the cute.Tensor object (layout / type information).
    print(src)
    # Element-by-element dump through the CuTe runtime.
    cute.print_tensor(src)
# Sparse emulation
class SparseEmulation:
    """Emulate a 2:4 sparse tensor-core GEMM with a plain element-wise kernel.

    Accumulates ``d[m, n] += a[m, km] * b[n, kn]`` where the mapping from the
    compressed A column ``km`` to the dense B column ``kn`` is decoded from
    the packed metadata tensor ``e`` (eight 4-bit codes per 32-bit word).
    """

    def __init__(self, M: int, N: int, K: int, L: int):
        # Problem shape. K is the dense (uncompressed) K extent; L is stored
        # but not read by the kernel below.
        self.M = M
        self.N = N
        self.K = K
        self.L = L

    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """Sparse emulation: launch one thread per output row.

        :param a: compressed A operand, indexed as (row, K // 2)
        :param b: dense B operand, indexed as (col, K)
        :param d: accumulator, indexed as (row, col); updated in place
        :param e: packed metadata, indexed as (row, K // 4 // 8)
        """
        num_threads = 128
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.kernel(a, b, d, e).launch(grid=grid, block=block)
        return

    @cute.kernel
    def kernel(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """CUDA kernel to emulate sparse tensor core"""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()

        # NOTE(review): the global row index is computed as tidx + bidx * M,
        # not tidx + bidx * blockDim (128). For M > 128 every block except
        # block 0 fails the guard below, so only the first 128 rows are
        # processed — confirm whether this is intended.
        row_idx = tidx + bidx * self.M
        # Number of 32-bit metadata words per row (each word covers 32 dense
        # K elements: 8 codes x 4 elements).
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # each thread process 1 row
            for col in range(self.N):
                # each meta_idx stands for 32 elements
                for e_idx in range(meta_idx):
                    meta_val = e[(row_idx, e_idx)]
                    for k in range(8):
                        # each k stands for 4 elements
                        meta_row = (meta_val >> (k * 4)) & 0xF
                        # Low/high 2-bit fields give the two kept positions
                        # inside the 4-element chunk.
                        idx0 = meta_row & 0x3
                        idx1 = (meta_row >> 2) & 0x3
                        # calculate the idx in b tensor which has value in A tensor
                        km = e_idx * 16 + k * 2
                        km_1 = km + 1
                        kn = e_idx * 32 + k * 4 + idx0
                        kn_1 = e_idx * 32 + k * 4 + idx1
                        d[row_idx, col] += a[row_idx, km] * b[col, kn]
                        d[row_idx, col] += a[row_idx, km_1] * b[col, kn_1]
        return
# Compressor
# compress a sparse tensor to a dense tensor && generate metadata
class Compressor:
    """Compress a 4:2 structured-sparse (M, K) tensor.

    For every run of 4 consecutive K elements (at most 2 of which are
    non-zero) the compressor keeps the two retained values in
    ``a_compressed`` of shape (M, K // 2) and records which positions were
    kept as a 4-bit code packed eight-per-word into ``meta`` of shape
    (M, K // 4 // 8).
    """

    def __init__(self, M: int, K: int, L: int):
        # Problem shape; L is stored but unused by the methods below.
        self.M = M
        self.K = K
        self.L = L
        # 4-bit metadata code -> the two chunk positions that are kept.
        self.pos_map = {
            0x4: [0, 1],
            0x8: [0, 2],
            0xC: [0, 3],
            0x9: [1, 2],
            0xD: [1, 3],
            0xE: [2, 3],
        }

    # NOTE(review): `_init__` looks like a typo (neither `__init__` nor a
    # conventional helper name) and re-invokes __init__ from a tensor shape;
    # confirm whether any caller actually uses it.
    @cute.jit
    def _init__(self, a: cute.Tensor):
        self.__init__(a.shape[0], a.shape[1], a.shape[2])

    def compress(self, a, a_compressed, meta, run_on_cpu: bool):
        """Dispatch compression to the CPU or CUDA implementation.

        :param a: dense (M, K) input tensor with a 4:2 sparsity pattern
        :param a_compressed: preallocated (M, K // 2) output tensor
        :param meta: preallocated (M, K // 4 // 8) packed-metadata output
        :param run_on_cpu: True to run on CPU, False to launch the CUDA path
        :raises ValueError: if ``a`` lives on the wrong device for the path
        """
        if run_on_cpu:
            if a.device.type != "cpu":
                raise ValueError("a must be on cpu")
            return self.__compress_on_cpu(a, a_compressed, meta)
        else:
            if a.device.type != "cuda":
                raise ValueError("a must be on cuda")
            return self.__compress_on_cuda(a, a_compressed, meta)

    def __compress_on_cpu(self, a, a_compressed, meta):
        """
        compress the tensor on cpu
        # Convert to 4-bit metadata value
        # The metadata value represents which 2 elements are non-zero
        # 0x4: [1,1,0,0] - first two elements are non-zero
        # 0x8: [1,0,1,0] - first and third elements are non-zero
        # 0xC: [1,0,0,1] - first and fourth elements are non-zero
        # 0x9: [0,1,1,0] - second and third elements are non-zero
        # 0xD: [0,1,0,1] - second and fourth elements are non-zero
        # 0xE: [0,0,1,1] - third and fourth elements are non-zero
        # special case:
        # [0,0,0,0] == [0,0,1,1]
        # [1,0,0,0] == [1,0,0,1]
        # [0,1,0,0] == [0,1,0,1]
        # [0,0,1,0] == [0,0,1,1]
        # [0,0,0,1] == [0,0,1,1]
        """
        M, K = a.shape
        assert a_compressed.shape == (
            M,
            K // 2,
        ), f"Expected a_compressed shape {(M, K // 2)}, got {a_compressed.shape}"
        assert meta.shape == (
            M,
            K // 4 // 8,
        ), f"Expected meta shape {(M, K // 4 // 8)}, got {meta.shape}"
        for m in range(M):
            k_meta = 0
            for k in range(0, K, 4):
                chunk = a[m, k : k + 4]

                # torch.nonzero returns int64 indices; squeeze() turns a
                # single index into a 0-dim tensor (compared against
                # torch.tensor(i) below).
                non_zero_indices = torch.nonzero(chunk).squeeze()
                meta_val = 0xE
                if torch.equal(non_zero_indices, torch.tensor([0, 1])):
                    meta_val = 0x4
                elif torch.equal(non_zero_indices, torch.tensor([0, 2])):
                    meta_val = 0x8
                elif torch.equal(non_zero_indices, torch.tensor([0, 3])) or torch.equal(
                    non_zero_indices, torch.tensor(0)
                ):
                    meta_val = 0xC
                elif torch.equal(non_zero_indices, torch.tensor([1, 2])):
                    meta_val = 0x9
                elif torch.equal(non_zero_indices, torch.tensor([1, 3])) or torch.equal(
                    non_zero_indices, torch.tensor(1)
                ):
                    meta_val = 0xD
                elif torch.equal(non_zero_indices, torch.tensor([2, 3])) or torch.equal(
                    non_zero_indices, torch.tensor(2)
                ):
                    meta_val = 0xE
                # BUGFIX: torch.equal() is dtype-sensitive and torch.nonzero()
                # yields int64 indices, while torch.tensor([]) defaults to
                # float32 — so the previous
                # `torch.equal(non_zero_indices, torch.tensor([]))` test was
                # always False and an all-zero chunk fell through to the
                # ValueError branch, despite the documented special case
                # [0,0,0,0] == [0,0,1,1]. Test emptiness via numel() instead.
                elif non_zero_indices.numel() == 0 or torch.equal(
                    non_zero_indices, torch.tensor(3)
                ):
                    meta_val = 0xE
                else:
                    raise ValueError(f"Invalid non-zero pattern: {non_zero_indices}")
                meta_idx = k // 4 // 8
                meta_bit_pos = (k // 4) % 8
                # Zero the 32-bit metadata word the first time this row
                # touches it, before OR-ing codes into it.
                if k_meta == meta_idx:
                    k_meta = meta_idx + 1
                    meta[m, meta_idx] = 0
                meta[m, meta_idx] |= meta_val << (meta_bit_pos * 4)
                # Each 4-element chunk contributes 2 kept values.
                compressed_idx = k // 2
                index = self.pos_map[meta_val]
                a_compressed[m, compressed_idx] = chunk[index[0]]
                a_compressed[m, compressed_idx + 1] = chunk[index[1]]

    def __compress_on_cuda(self, a, a_compressed, meta):
        """
        compress the tensor on cuda

        Wraps the torch tensors as CuTe tensors via DLPack and launches the
        compression kernel.
        """
        a_tensor = from_dlpack(a)
        a_compressed_tensor = from_dlpack(a_compressed)
        meta_tensor = from_dlpack(meta)
        self.compress_on_cuda_impl(a_tensor, a_compressed_tensor, meta_tensor)
        return

    @cute.jit
    def compress_on_cuda_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """Compress the input tensor using the metadata (one thread per row)."""
        num_threads = 128
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.compressor_impl(a, a_compressed, meta).launch(grid=grid, block=block)

    @cute.kernel
    def compressor_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """CUDA kernel to compress the tensor"""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()

        # each thread process 1 row
        # NOTE(review): as in SparseEmulation.kernel, this uses
        # tidx + bidx * M rather than tidx + bidx * blockDim (128); for
        # M > 128 only block 0 passes the guard below — confirm intent.
        row_idx = tidx + bidx * self.M
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # each meta_idx stands for 32 elements
            for i in range(meta_idx):
                meta[row_idx, i] = 0
                # each k stands for 4 elements
                for j in range(8):
                    val = a[row_idx, i * 32 + j * 4]
                    val_1 = a[row_idx, i * 32 + j * 4 + 1]
                    val_2 = a[row_idx, i * 32 + j * 4 + 2]
                    val_3 = a[row_idx, i * 32 + j * 4 + 3]
                    # 0/1 flags for which of the 4 chunk positions are non-zero.
                    value_idx = 0
                    value_idx_1 = 0
                    value_idx_2 = 0
                    value_idx_3 = 0
                    pos0 = 0
                    pos1 = 0
                    if val != 0:
                        value_idx = 1
                        pos0 = 0
                    if val_1 != 0:
                        value_idx_1 = 1
                    if val_2 != 0:
                        value_idx_2 = 1
                    if val_3 != 0:
                        value_idx_3 = 1
                    pos = [value_idx, value_idx_1, value_idx_2, value_idx_3]
                    # tmp is the 4-bit metadata code; pos0/pos1 are the two
                    # chunk positions whose values are kept.
                    # NOTE(review): patterns with 3 or 4 non-zeros match no
                    # branch, leaving tmp == 0 (an invalid code) silently —
                    # the CPU path raises ValueError for those instead.
                    tmp = 0
                    if pos == [0, 0, 0, 0]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [1, 0, 0, 0]:
                        tmp = 0xC
                        pos0 = 0
                        pos1 = 3
                    elif pos == [0, 1, 0, 0]:
                        tmp = 0xD
                        pos0 = 1
                        pos1 = 3
                    elif pos == [0, 0, 1, 0]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [0, 0, 0, 1]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [1, 1, 0, 0]:
                        tmp = 0x4
                        pos0 = 0
                        pos1 = 1
                    elif pos == [1, 0, 1, 0]:
                        tmp = 0x8
                        pos0 = 0
                        pos1 = 2
                    elif pos == [1, 0, 0, 1]:
                        tmp = 0xC
                        pos0 = 0
                        pos1 = 3
                    elif pos == [0, 1, 1, 0]:
                        tmp = 0x9
                        pos0 = 1
                        pos1 = 2
                    elif pos == [0, 1, 0, 1]:
                        tmp = 0xD
                        pos0 = 1
                        pos1 = 3
                    elif pos == [0, 0, 1, 1]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    # cute.printf(row_idx, cutlass.Float32(val), cutlass.Float32(val_1), cutlass.Float32(val_2), cutlass.Float32(val_3), tmp)
                    meta[row_idx, i] |= tmp << (j * 4)

                    a_compressed[row_idx, i * 16 + j * 2] = a[
                        row_idx, i * 32 + j * 4 + pos0
                    ]
                    a_compressed[row_idx, i * 16 + j * 2 + 1] = a[
                        row_idx, i * 32 + j * 4 + pos1
                    ]

        return
# SparseUtils is used to generate sparse tensor
# format torch.Tensor
class SparseUtils:
    #!brief: SparseUtils is used to generate a 4:2 structured-sparse tensor
    #        (every group of 4 consecutive K elements has exactly 2 zeros)
    #!param: M: int, K: int, L: int, dtype: cutlass.DataType
    def __init__(self, M: int, K: int, L: int, dtype):
        self.M = M
        self.K = K
        self.L = L
        self.dtype = dtype
        # Unpacked metadata: one 4-bit code per 4-element chunk, shape (M, K // 4).
        self.meta_data = self._generate_meta_data_4_2()
        # When True, the CPU generator masks `a` with self.meta_data instead of
        # choosing random zero positions.
        self._use_specific_meta_data = False

    #!brief: map self.dtype (cutlass.DataType) to the corresponding torch dtype
    def _get_type(self):
        if self.dtype == cutlass.Float16:
            return torch.float16
        elif self.dtype == cutlass.Float32:
            return torch.float32
        elif self.dtype == cutlass.Int8:
            return torch.int8
        else:
            raise ValueError(f"Unsupported dtype: {self.dtype}")

    def _generate_meta_data_4_2(self):
        # metadata for 4:2 sparse will be in range (4, 8, 9, c, d, e)
        # represents
        # 0: [1,1,0,0] non-zero pos 00,01 -> 0100 = 4
        # 1: [1,0,1,0] non-zero pos 00,10 -> 1000 = 8
        # 2: [1,0,0,1] non-zero pos 00,11 -> 1100 = c
        # 3: [0,1,1,0] non-zero pos 01,10 -> 1001 = 9
        # 4: [0,1,0,1] non-zero pos 01,11 -> 1101 = d
        # 5: [0,0,1,1] non-zero pos 10,11 -> 1011 = e
        meta_value = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE]
        # 4:2 sparse, so each chunk is 4 elements, map to 4 bits
        K_NumChunk = self.K // 4
        meta_data = np.random.choice(
            meta_value, size=(self.M, K_NumChunk), replace=True
        )
        meta_data = torch.from_numpy(
            np.array(meta_data).astype(np.uint8).reshape(self.M, K_NumChunk)
        )
        return meta_data

    #!brief: pack meta data — fold eight 4-bit chunk codes into each uint32,
    #        returning a (M, K // 4 // 8) tensor
    def _pack_meta_data(self):
        tmp = []
        K_NumChunk = self.K // 4
        for i in range(self.M):
            for j in range(K_NumChunk // 8):
                v = 0
                for k in range(8):
                    # Chunk code k occupies bits [4k, 4k+4) of the word.
                    vv = int(self.meta_data[i, j * 8 + k] & 0xF)
                    tt = vv << (k * 4)
                    v = v | tt
                tmp.append(v)
        # debug print
        # print([hex(vt) for vt in tmp])
        result = torch.from_numpy(
            np.array(tmp).astype(np.uint32).reshape(self.M, K_NumChunk // 8)
        )
        return result

    #!brief: use specific meta data (keeps the generated metadata when None)
    def use_specific_meta_data(self, meta_data: torch.Tensor = None):
        if meta_data is not None:
            self.meta_data = meta_data
        self._use_specific_meta_data = True

    #!brief: generate sparse tensor with tensor
    #!param: a: torch.Tensor
    #!param: run_on_cpu: bool
    #!return: torch.Tensor
    def generate_sparse_4_2_tensor_with_tensor(self, a, run_on_cpu):
        if run_on_cpu:
            if a.device.type != "cpu":
                raise ValueError("a must be on cpu")
            return self.__generate_sparse_tensor_cpu(a)
        else:
            if a.device.type != "cuda":
                raise ValueError("a must be on cuda")
            a_tensor = from_dlpack(a)
            # CUDA path always applies self.meta_data (packed on the host,
            # then moved to the device).
            packed_meta_data = self._pack_meta_data()
            meta_tensor = from_dlpack(packed_meta_data.cuda())
            self.__generate_sparse_tensor_cuda(a_tensor, meta_tensor)
            return a

    #!brief: generate sparse tensor (random values in [-5, 4])
    #!param: run_on_cpu: bool
    #!return: torch.Tensor
    def generate_4_2_sparse_tensor(self, run_on_cpu):
        dtype = self._get_type()
        a = torch.empty(self.M, self.K).random_(-5, 5).to(dtype)
        if run_on_cpu:
            return self.generate_sparse_4_2_tensor_with_tensor(a, run_on_cpu)
        else:
            return self.generate_sparse_4_2_tensor_with_tensor(a.cuda(), run_on_cpu)

    #!brief: generate sparse tensor on cpu (modifies `a` in place on the
    #        random path; returns a masked copy on the specific-metadata path)
    #!param: a: torch.Tensor
    #!return: torch.Tensor
    def __generate_sparse_tensor_cpu(self, a):
        if not self._use_specific_meta_data:
            for m in range(self.M):
                for k in range(0, self.K, 4):
                    # random choose 2 zero positions
                    zero_indices = torch.randperm(4)[:2]
                    a[m, k + zero_indices[0]] = 0
                    a[m, k + zero_indices[1]] = 0
            return a
        else:
            # use specific meta data: build a flat 0/1 mask from the
            # per-chunk metadata codes and multiply it into `a`.
            tensor_mask = []
            for i in range(self.M):
                for j in range(self.K // 4):
                    meta_val = self.meta_data[i, j]
                    tmp = []
                    if meta_val == 0x4:
                        tmp = [1, 1, 0, 0]
                    elif meta_val == 0x8:
                        tmp = [1, 0, 1, 0]
                    elif meta_val == 0xC:
                        tmp = [1, 0, 0, 1]
                    elif meta_val == 0x9:
                        tmp = [0, 1, 1, 0]
                    elif meta_val == 0xD:
                        tmp = [0, 1, 0, 1]
                    elif meta_val == 0xE:
                        tmp = [0, 0, 1, 1]
                    # NOTE(review): an unrecognized code leaves tmp empty,
                    # which would make the mask shorter than M*K and break the
                    # reshape below — assumes meta_data only holds the six
                    # valid codes.
                    tensor_mask.extend(tmp)
            a = torch.reshape(a, (-1,))
            mask = torch.tensor(tensor_mask)
            a = a * mask
            a = torch.reshape(a, (self.M, self.K))
            return a

    @cute.jit
    def __generate_sparse_tensor_cuda(self, a: cute.Tensor, meta: cute.Tensor):
        """Generate a sparse tensor from a dense tensor using metadata"""
        assert a.shape[0] == self.M and a.shape[1] == self.K
        assert meta.shape[0] == self.M and meta.shape[1] == self.K // 4 // 8
        num_threads = 128
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.kernel(a, meta).launch(grid=grid, block=block)

    @cute.kernel
    def kernel(self, a: cute.Tensor, meta: cute.Tensor):
        """Apply sparsity mask to input tensor using metadata"""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()

        # each thread process 1 row
        # NOTE(review): row index uses tidx + bidx * M, not
        # tidx + bidx * blockDim (128); for M > 128 only block 0 passes the
        # guard below — confirm intent.
        row_idx = tidx + bidx * self.M
        meta_idx = self.K // 4 // 8
        # each thread process 1 row
        if row_idx < self.M:
            # iterate over each chunk(32 elements)
            for i in range(meta_idx):
                meta_val = meta[(row_idx, i)]
                # iterate over each sparse pattern(4 elements)
                for j in range(8):
                    meta_row = (meta_val >> (j * 4)) & 0xF
                    # idx0/idx1 are the two kept (non-zero) positions.
                    idx0 = meta_row & 0x3
                    idx1 = (meta_row >> 2) & 0x3
                    r_id0 = 0
                    r_id1 = 0
                    # r_id is the idx that value is 0
                    if idx0 >= 2 and idx1 >= 2:
                        r_id0 = 0
                        r_id1 = 1
                    elif idx0 <= 1 and idx1 <= 1:
                        r_id0 = 2
                        r_id1 = 3
                    else:
                        # One kept position in each half: the zeroed position
                        # is the other element of the same pair (XOR with 1).
                        r_id0 = idx0 ^ 0b1
                        r_id1 = idx1 ^ 0b1
                    row_id0 = r_id0 + i * 32 + j * 4
                    row_id1 = r_id1 + i * 32 + j * 4
                    a[row_idx, row_id0] = self.dtype(0.0)
                    a[row_idx, row_id1] = self.dtype(0.0)
        return
Reference in New Issue
Block a user