cutlass/examples/python/CuTeDSL/utils/sparse_utils.py

import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch


@cute.jit
def print_tensor_dlpack(src: cute.Tensor):
    """Print a CuTe tensor: `print` shows the layout, `cute.print_tensor` the values"""
    print(src)
    cute.print_tensor(src)
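

# A minimal usage sketch for the helper above (not part of the original
# utilities): wrap a torch tensor via DLPack and print it. The function name
# `_demo_print_tensor` is hypothetical, and printing a host tensor this way
# assumes a working CuTe DSL installation.
def _demo_print_tensor():
    x = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    print_tensor_dlpack(from_dlpack(x))

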
# Sparse emulation
class SparseEmulation:
def __init__(self, M: int, N: int, K: int, L: int):
self.M = M
self.N = N
self.K = K
self.L = L
@cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """Emulate a 4:2 sparse GEMM: d[M,N] += expand(a[M,K/2], e) @ b[N,K]^T"""
        num_threads = 128
        # one thread per output row
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
self.kernel(a, b, d, e).launch(grid=grid, block=block)
return
@cute.kernel
    def kernel(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """CUDA kernel that emulates the 4:2 sparse tensor core MMA"""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        bdimx, bdimy, bdimz = cute.arch.block_dim()
        # global row index: one thread per output row
        row_idx = tidx + bidx * bdimx
        # number of packed 32-bit metadata words per row
        num_meta = self.K // 4 // 8
        if row_idx < self.M:
            # each thread processes one row of d
            for col in range(self.N):
                # each metadata word covers 32 dense elements
                for e_idx in range(num_meta):
meta_val = e[(row_idx, e_idx)]
                    for k in range(8):
                        # each 4-bit nibble describes one 4-element chunk
                        meta_row = (meta_val >> (k * 4)) & 0xF
                        idx0 = meta_row & 0x3
                        idx1 = (meta_row >> 2) & 0x3
                        # map compressed-A indices (km) to the dense-K
                        # positions in b (kn) that hold the kept values
                        km = e_idx * 16 + k * 2
                        km_1 = km + 1
                        kn = e_idx * 32 + k * 4 + idx0
                        kn_1 = e_idx * 32 + k * 4 + idx1
d[row_idx, col] += a[row_idx, km] * b[col, kn]
d[row_idx, col] += a[row_idx, km_1] * b[col, kn_1]
return
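

# A worked example of the nibble decoding used in the kernel above (a pure-
# Python sketch added for illustration; `decode_meta_nibble` is not part of
# the original file). Each 4-bit metadata value packs two 2-bit positions:
# 0x9 == 0b1001 decodes to idx0 = 0b01 = 1 and idx1 = 0b10 = 2, i.e. the
# chunk pattern [0, 1, 1, 0].
def decode_meta_nibble(nibble: int):
    """Return the two kept positions (within a 4-element chunk) for a nibble."""
    idx0 = nibble & 0x3
    idx1 = (nibble >> 2) & 0x3
    return idx0, idx1


assert decode_meta_nibble(0x9) == (1, 2)
assert decode_meta_nibble(0x4) == (0, 1)
assert decode_meta_nibble(0xE) == (2, 3)

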
# Compressor
# Compresses a 4:2-sparse tensor into dense (M x K/2) values and packed metadata
class Compressor:
def __init__(self, M: int, K: int, L: int):
self.M = M
self.K = K
self.L = L
        # metadata nibble -> positions of the two kept elements in a 4-chunk
        self.pos_map = {
0x4: [0, 1],
0x8: [0, 2],
0xC: [0, 3],
0x9: [1, 2],
0xD: [1, 3],
0xE: [2, 3],
}
    def init_from_tensor(self, a: cute.Tensor):
        """Re-initialize the compressor dimensions from a rank-3 tensor's shape."""
        self.__init__(a.shape[0], a.shape[1], a.shape[2])
def compress(self, a, a_compressed, meta, run_on_cpu: bool):
if run_on_cpu:
if a.device.type != "cpu":
raise ValueError("a must be on cpu")
return self.__compress_on_cpu(a, a_compressed, meta)
else:
if a.device.type != "cuda":
raise ValueError("a must be on cuda")
return self.__compress_on_cuda(a, a_compressed, meta)
    def __compress_on_cpu(self, a, a_compressed, meta):
        """
        Compress the tensor on the CPU.

        Each 4-element chunk maps to a 4-bit metadata value encoding which
        two elements are non-zero:
          0x4: [1,1,0,0] - first and second elements are non-zero
          0x8: [1,0,1,0] - first and third elements are non-zero
          0xC: [1,0,0,1] - first and fourth elements are non-zero
          0x9: [0,1,1,0] - second and third elements are non-zero
          0xD: [0,1,0,1] - second and fourth elements are non-zero
          0xE: [0,0,1,1] - third and fourth elements are non-zero
        Chunks with fewer than two non-zeros are canonicalized:
          [0,0,0,0] -> 0xE   [1,0,0,0] -> 0xC   [0,1,0,0] -> 0xD
          [0,0,1,0] -> 0xE   [0,0,0,1] -> 0xE
        """
M, K = a.shape
assert a_compressed.shape == (
M,
K // 2,
), f"Expected a_compressed shape {(M, K // 2)}, got {a_compressed.shape}"
assert meta.shape == (
M,
K // 4 // 8,
), f"Expected meta shape {(M, K // 4 // 8)}, got {meta.shape}"
for m in range(M):
k_meta = 0
for k in range(0, K, 4):
chunk = a[m, k : k + 4]
non_zero_indices = torch.nonzero(chunk).squeeze()
meta_val = 0xE
if torch.equal(non_zero_indices, torch.tensor([0, 1])):
meta_val = 0x4
elif torch.equal(non_zero_indices, torch.tensor([0, 2])):
meta_val = 0x8
elif torch.equal(non_zero_indices, torch.tensor([0, 3])) or torch.equal(
non_zero_indices, torch.tensor(0)
):
meta_val = 0xC
elif torch.equal(non_zero_indices, torch.tensor([1, 2])):
meta_val = 0x9
elif torch.equal(non_zero_indices, torch.tensor([1, 3])) or torch.equal(
non_zero_indices, torch.tensor(1)
):
meta_val = 0xD
elif torch.equal(non_zero_indices, torch.tensor([2, 3])) or torch.equal(
non_zero_indices, torch.tensor(2)
):
meta_val = 0xE
                elif non_zero_indices.numel() == 0 or torch.equal(
                    non_zero_indices, torch.tensor(3)
                ):
meta_val = 0xE
else:
raise ValueError(f"Invalid non-zero pattern: {non_zero_indices}")
                # which packed 32-bit word, and which nibble slot within it
                meta_idx = k // 4 // 8
                meta_bit_pos = (k // 4) % 8
                # zero the word the first time it is touched
                if k_meta == meta_idx:
                    k_meta = meta_idx + 1
                    meta[m, meta_idx] = 0
                meta[m, meta_idx] |= meta_val << (meta_bit_pos * 4)
compressed_idx = k // 2
index = self.pos_map[meta_val]
a_compressed[m, compressed_idx] = chunk[index[0]]
a_compressed[m, compressed_idx + 1] = chunk[index[1]]
    def __compress_on_cuda(self, a, a_compressed, meta):
        """
        Compress the tensor on the GPU using the CuTe DSL kernel below.
        """
a_tensor = from_dlpack(a)
a_compressed_tensor = from_dlpack(a_compressed)
meta_tensor = from_dlpack(meta)
self.compress_on_cuda_impl(a_tensor, a_compressed_tensor, meta_tensor)
return
@cute.jit
    def compress_on_cuda_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """Launch the compression kernel: fills a_compressed and meta from a"""
num_threads = 128
grid = (cute.ceil_div(self.M, num_threads), 1, 1)
block = (num_threads, 1, 1)
self.compressor_impl(a, a_compressed, meta).launch(grid=grid, block=block)
@cute.kernel
    def compressor_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """CUDA kernel that compresses one row per thread"""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        bdimx, bdimy, bdimz = cute.arch.block_dim()
        # global row index: one thread per row
        row_idx = tidx + bidx * bdimx
        # number of packed 32-bit metadata words per row
        num_meta = self.K // 4 // 8
        if row_idx < self.M:
            # each metadata word covers 32 elements
            for i in range(num_meta):
meta[row_idx, i] = 0
                # each j indexes one 4-element chunk (one metadata nibble)
                for j in range(8):
val = a[row_idx, i * 32 + j * 4]
val_1 = a[row_idx, i * 32 + j * 4 + 1]
val_2 = a[row_idx, i * 32 + j * 4 + 2]
val_3 = a[row_idx, i * 32 + j * 4 + 3]
value_idx = 0
value_idx_1 = 0
value_idx_2 = 0
value_idx_3 = 0
pos0 = 0
pos1 = 0
                    if val != 0:
                        value_idx = 1
if val_1 != 0:
value_idx_1 = 1
if val_2 != 0:
value_idx_2 = 1
if val_3 != 0:
value_idx_3 = 1
                    # non-zero mask of the 4-element chunk
                    pos = [value_idx, value_idx_1, value_idx_2, value_idx_3]
                    # tmp is the metadata nibble; pos0/pos1 the two kept positions
                    tmp = 0
if pos == [0, 0, 0, 0]:
tmp = 0xE
pos0 = 2
pos1 = 3
elif pos == [1, 0, 0, 0]:
tmp = 0xC
pos0 = 0
pos1 = 3
elif pos == [0, 1, 0, 0]:
tmp = 0xD
pos0 = 1
pos1 = 3
elif pos == [0, 0, 1, 0]:
tmp = 0xE
pos0 = 2
pos1 = 3
elif pos == [0, 0, 0, 1]:
tmp = 0xE
pos0 = 2
pos1 = 3
elif pos == [1, 1, 0, 0]:
tmp = 0x4
pos0 = 0
pos1 = 1
elif pos == [1, 0, 1, 0]:
tmp = 0x8
pos0 = 0
pos1 = 2
elif pos == [1, 0, 0, 1]:
tmp = 0xC
pos0 = 0
pos1 = 3
elif pos == [0, 1, 1, 0]:
tmp = 0x9
pos0 = 1
pos1 = 2
elif pos == [0, 1, 0, 1]:
tmp = 0xD
pos0 = 1
pos1 = 3
elif pos == [0, 0, 1, 1]:
tmp = 0xE
pos0 = 2
pos1 = 3
# cute.printf(row_idx, cutlass.Float32(val), cutlass.Float32(val_1), cutlass.Float32(val_2), cutlass.Float32(val_3), tmp)
meta[row_idx, i] |= tmp << (j * 4)
a_compressed[row_idx, i * 16 + j * 2] = a[
row_idx, i * 32 + j * 4 + pos0
]
a_compressed[row_idx, i * 16 + j * 2 + 1] = a[
row_idx, i * 32 + j * 4 + pos1
]
return
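

# A minimal CPU usage sketch for the compressor (added for illustration, not
# part of the original file; the name `_demo_compress_cpu` and the shapes are
# assumptions). SparseUtils is defined below and resolved at call time.
def _demo_compress_cpu(M: int = 4, K: int = 32):
    a = SparseUtils(M, K, 1, cutlass.Float32).generate_4_2_sparse_tensor(run_on_cpu=True)
    a_compressed = torch.zeros(M, K // 2)
    # int64 comfortably holds the packed 32-bit metadata words
    meta = torch.zeros(M, K // 4 // 8, dtype=torch.int64)
    Compressor(M, K, 1).compress(a, a_compressed, meta, run_on_cpu=True)
    return a, a_compressed, meta

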
# SparseUtils generates 4:2-sparse torch.Tensor inputs and matching metadata
class SparseUtils:
    #!brief: generate 4:2 sparse tensors and their packed metadata
    #!param: M: int, K: int, L: int, dtype: cutlass.DataType
def __init__(self, M: int, K: int, L: int, dtype):
self.M = M
self.K = K
self.L = L
self.dtype = dtype
self.meta_data = self._generate_meta_data_4_2()
self._use_specific_meta_data = False
    #!brief: map cutlass.DataType to the corresponding torch dtype
def _get_type(self):
if self.dtype == cutlass.Float16:
return torch.float16
elif self.dtype == cutlass.Float32:
return torch.float32
elif self.dtype == cutlass.Int8:
return torch.int8
else:
raise ValueError(f"Unsupported dtype: {self.dtype}")
def _generate_meta_data_4_2(self):
        # metadata for 4:2 sparsity is drawn from {4, 8, 9, c, d, e};
        # each value encodes the two non-zero positions of a 4-element chunk:
        # 0: [1,1,0,0] non-zero pos 00,01 -> 0100 = 4
        # 1: [1,0,1,0] non-zero pos 00,10 -> 1000 = 8
        # 2: [1,0,0,1] non-zero pos 00,11 -> 1100 = c
        # 3: [0,1,1,0] non-zero pos 01,10 -> 1001 = 9
        # 4: [0,1,0,1] non-zero pos 01,11 -> 1101 = d
        # 5: [0,0,1,1] non-zero pos 10,11 -> 1110 = e
meta_value = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE]
        # 4:2 sparsity: each 4-element chunk maps to one 4-bit metadata value
K_NumChunk = self.K // 4
meta_data = np.random.choice(
meta_value, size=(self.M, K_NumChunk), replace=True
)
meta_data = torch.from_numpy(
np.array(meta_data).astype(np.uint8).reshape(self.M, K_NumChunk)
)
return meta_data
    #!brief: pack eight 4-bit metadata values into one uint32 word per group
def _pack_meta_data(self):
tmp = []
K_NumChunk = self.K // 4
for i in range(self.M):
for j in range(K_NumChunk // 8):
v = 0
for k in range(8):
vv = int(self.meta_data[i, j * 8 + k] & 0xF)
tt = vv << (k * 4)
v = v | tt
tmp.append(v)
# debug print
# print([hex(vt) for vt in tmp])
result = torch.from_numpy(
np.array(tmp).astype(np.uint32).reshape(self.M, K_NumChunk // 8)
)
return result
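    # Worked example of the packing above (illustration): a row whose first
    # eight metadata nibbles are all 0x4 packs into the single 32-bit word
    # 0x44444444, since nibble k lands at bit offset 4 * k.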
    #!brief: mark that sparsity should follow the stored metadata; optionally override it
def use_specific_meta_data(self, meta_data: torch.Tensor = None):
if meta_data is not None:
self.meta_data = meta_data
self._use_specific_meta_data = True
    #!brief: apply 4:2 sparsity to a caller-provided tensor
#!param: a: torch.Tensor
#!param: run_on_cpu: bool
#!return: torch.Tensor
def generate_sparse_4_2_tensor_with_tensor(self, a, run_on_cpu):
if run_on_cpu:
if a.device.type != "cpu":
raise ValueError("a must be on cpu")
return self.__generate_sparse_tensor_cpu(a)
else:
if a.device.type != "cuda":
raise ValueError("a must be on cuda")
a_tensor = from_dlpack(a)
packed_meta_data = self._pack_meta_data()
meta_tensor = from_dlpack(packed_meta_data.cuda())
self.__generate_sparse_tensor_cuda(a_tensor, meta_tensor)
return a
#!brief: generate sparse tensor
#!param: run_on_cpu: bool
#!return: torch.Tensor
def generate_4_2_sparse_tensor(self, run_on_cpu):
dtype = self._get_type()
a = torch.empty(self.M, self.K).random_(-5, 5).to(dtype)
if run_on_cpu:
return self.generate_sparse_4_2_tensor_with_tensor(a, run_on_cpu)
else:
return self.generate_sparse_4_2_tensor_with_tensor(a.cuda(), run_on_cpu)
#!brief: generate sparse tensor on cpu
#!param: a: torch.Tensor
#!return: torch.Tensor
def __generate_sparse_tensor_cpu(self, a):
if not self._use_specific_meta_data:
for m in range(self.M):
for k in range(0, self.K, 4):
                # randomly choose 2 of the 4 positions to zero out
zero_indices = torch.randperm(4)[:2]
a[m, k + zero_indices[0]] = 0
a[m, k + zero_indices[1]] = 0
return a
else:
# use specific meta data
tensor_mask = []
for i in range(self.M):
for j in range(self.K // 4):
meta_val = self.meta_data[i, j]
tmp = []
if meta_val == 0x4:
tmp = [1, 1, 0, 0]
elif meta_val == 0x8:
tmp = [1, 0, 1, 0]
elif meta_val == 0xC:
tmp = [1, 0, 0, 1]
elif meta_val == 0x9:
tmp = [0, 1, 1, 0]
elif meta_val == 0xD:
tmp = [0, 1, 0, 1]
elif meta_val == 0xE:
tmp = [0, 0, 1, 1]
tensor_mask.extend(tmp)
a = torch.reshape(a, (-1,))
mask = torch.tensor(tensor_mask)
a = a * mask
a = torch.reshape(a, (self.M, self.K))
return a
@cute.jit
def __generate_sparse_tensor_cuda(self, a: cute.Tensor, meta: cute.Tensor):
"""Generate a sparse tensor from a dense tensor using metadata"""
assert a.shape[0] == self.M and a.shape[1] == self.K
assert meta.shape[0] == self.M and meta.shape[1] == self.K // 4 // 8
num_threads = 128
grid = (cute.ceil_div(self.M, num_threads), 1, 1)
block = (num_threads, 1, 1)
self.kernel(a, meta).launch(grid=grid, block=block)
@cute.kernel
    def kernel(self, a: cute.Tensor, meta: cute.Tensor):
        """Apply the sparsity mask to the input tensor using metadata"""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        bdimx, bdimy, bdimz = cute.arch.block_dim()
        # global row index: one thread per row
        row_idx = tidx + bidx * bdimx
        # number of packed 32-bit metadata words per row
        num_meta = self.K // 4 // 8
        if row_idx < self.M:
            # iterate over each 32-element chunk
            for i in range(num_meta):
meta_val = meta[(row_idx, i)]
                # iterate over the eight 4-element patterns (one nibble each)
                for j in range(8):
meta_row = (meta_val >> (j * 4)) & 0xF
idx0 = meta_row & 0x3
idx1 = (meta_row >> 2) & 0x3
                    r_id0 = 0
                    r_id1 = 0
                    # r_id0/r_id1 are the two positions to zero out
if idx0 >= 2 and idx1 >= 2:
r_id0 = 0
r_id1 = 1
elif idx0 <= 1 and idx1 <= 1:
r_id0 = 2
r_id1 = 3
else:
r_id0 = idx0 ^ 0b1
r_id1 = idx1 ^ 0b1
row_id0 = r_id0 + i * 32 + j * 4
row_id1 = r_id1 + i * 32 + j * 4
a[row_idx, row_id0] = self.dtype(0.0)
a[row_idx, row_id1] = self.dtype(0.0)
return
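

# End-to-end sketch (added for illustration, not part of the original
# utilities): generate a 4:2 sparse A, compress it on the CPU, and check the
# emulated sparse GEMM against a dense torch.matmul reference. The shapes and
# the `__main__` guard are assumptions; the GPU branch requires CUDA and a
# working CuTe DSL installation.
if __name__ == "__main__":
    M, N, K, L = 128, 8, 64, 1
    utils = SparseUtils(M, K, L, cutlass.Float32)
    a = utils.generate_4_2_sparse_tensor(run_on_cpu=True)
    b = torch.empty(N, K).random_(-5, 5).to(torch.float32)

    # compress A: (M, K) -> (M, K/2) values + (M, K/32) packed metadata words
    a_compressed = torch.zeros(M, K // 2)
    meta = torch.zeros(M, K // 4 // 8, dtype=torch.int64)
    Compressor(M, K, L).compress(a, a_compressed, meta, run_on_cpu=True)

    # reference result from the uncompressed sparse A
    d_ref = a @ b.t()

    if torch.cuda.is_available():
        d = torch.zeros(M, N).cuda()
        SparseEmulation(M, N, K, L)(
            from_dlpack(a_compressed.cuda()),
            from_dlpack(b.cuda()),
            from_dlpack(d),
            from_dlpack(meta.cuda()),
        )
        torch.testing.assert_close(d.cpu(), d_ref)
        print("sparse emulation matches dense reference")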