v4.3 update. (#2709)

* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
This commit is contained in:
Junkai-Wu
2025-10-22 02:26:30 +08:00
committed by GitHub
parent e6e2cc29f5
commit b1d6e2c9b3
244 changed files with 59272 additions and 10455 deletions

View File

@ -0,0 +1,975 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from typing import Tuple, Optional
import cutlass
from cutlass.cute.typing import Boolean
from cutlass.cutlass_dsl import (
Int32,
Float32,
min,
extract_mlir_values,
new_from_mlir_values,
)
from cutlass.utils.hardware_info import HardwareInfo
from cutlass.utils import WorkTileInfo
import cutlass.cute as cute
##############################################################################
# Fmha static tile scheduler
##############################################################################
class FmhaStaticTileSchedulerParams:
    """Configuration bundle for the FMHA static tile scheduler.

    Holds everything needed to build an :class:`FmhaStaticTileScheduler`
    and supports round-tripping its dynamic members through MLIR values.

    :ivar is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :ivar problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape
    """

    def __init__(
        self,
        is_persistent: bool,
        problem_shape_mbh: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """Store the scheduler configuration.

        :param is_persistent: Whether to use persistent kernel mode.
        :type is_persistent: bool
        :param problem_shape_mbh: Problem shape in (M, B, H) format.
        :type problem_shape_mbh: cute.Shape
        """
        self.is_persistent = is_persistent
        self.problem_shape_mbh = problem_shape_mbh
        self._loc = loc
        self._ip = ip

    def __extract_mlir_values__(self):
        # Flatten each dynamic member into MLIR values, remembering how many
        # values every member contributed so they can be rebuilt later.
        flattened = []
        self._values_pos = []
        for member in (self.problem_shape_mbh,):
            member_values = extract_mlir_values(member)
            flattened += member_values
            self._values_pos.append(len(member_values))
        return flattened

    def __new_from_mlir_values__(self, values):
        # Consume `values` in the same order/counts recorded by
        # __extract_mlir_values__ and rebuild each member.
        rebuilt = []
        remaining = values
        for member, count in zip((self.problem_shape_mbh,), self._values_pos):
            rebuilt.append(new_from_mlir_values(member, remaining[:count]))
            remaining = remaining[count:]
        return FmhaStaticTileSchedulerParams(
            self.is_persistent, *(tuple(rebuilt)), loc=self._loc
        )
class FmhaStaticTileScheduler:
    """A static tile scheduler for FMHA (Fused Multi-Head Attention) operations.

    This class manages the scheduling of work tiles for FMHA kernels, supporting
    both persistent and non-persistent kernel modes. It tracks the current work
    position and advances through the problem space efficiently.

    :ivar _params: Scheduler parameters.
    :type _params: FmhaStaticTileSchedulerParams
    :ivar _blk_coord: Block coordinates.
    :type _blk_coord: cute.Coord
    :ivar _grid_shape: Grid shape for the kernel.
    :type _grid_shape: cute.Shape
    :ivar _is_persistent: Whether to use persistent kernel mode.
    :type _is_persistent: bool
    :ivar _current_work_linear_idx: Current linear work index.
    :type _current_work_linear_idx: Int32
    :ivar _problem_shape_mbh: Problem shape in (M, B, H) format.
    :type _problem_shape_mbh: cute.Layout
    :ivar _num_blocks: Number of blocks in the problem.
    :type _num_blocks: Int32
    :ivar _is_first_block: Whether this is the first block.
    :type _is_first_block: bool
    :ivar num_persistent_sm: Number of persistent SMs.
    :type num_persistent_sm: Int32
    """

    def __init__(
        self,
        params: FmhaStaticTileSchedulerParams,
        current_work_linear_idx: Int32,
        blk_coord: cute.Coord,
        grid_shape: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """
        Initializes the FmhaStaticTileScheduler with the given parameters.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams
        :param current_work_linear_idx: Current linear work index.
        :type current_work_linear_idx: Int32
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param grid_shape: Grid shape for the kernel.
        :type grid_shape: cute.Shape
        """
        self._params = params
        self._blk_coord = blk_coord
        self._grid_shape = grid_shape
        self._is_persistent = params.is_persistent
        self._current_work_linear_idx = current_work_linear_idx
        # Layout over (M, B, H), used to convert a linear work index into a
        # hierarchical block coordinate in get_current_work (persistent mode).
        self._problem_shape_mbh = cute.make_layout(
            params.problem_shape_mbh, loc=loc, ip=ip
        )
        # Total number of work tiles in the problem space.
        self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip)
        self._is_first_block = True
        # In persistent mode the launched grid size is the number of SMs doing
        # persistent work; used as the stride in advance_to_next_work.
        self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip)
        self._loc = loc
        self._ip = ip

    # called by host
    @staticmethod
    def get_grid_shape(
        params: FmhaStaticTileSchedulerParams,
        *,
        loc=None,
        ip=None,
    ) -> cute.Shape:
        """
        Determine the grid shape for the FMHA kernel.

        For persistent kernels, the grid shape is limited by the number of SMs
        (Streaming Multiprocessors) available on the device. For non-persistent
        kernels, the grid shape matches the problem shape.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams
        :return: Grid shape as (M, B, H) tuple.
        :rtype: cute.Shape
        """
        if params.is_persistent:
            hardware_info = HardwareInfo()
            sm_count = hardware_info.get_device_multiprocessor_count()
            # Never launch more CTAs than there is work for.
            return (
                min(sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip)),
                1,
                1,
            )
        else:
            return params.problem_shape_mbh

    @staticmethod
    def check_valid_work_for_seqlen_q(
        q_tiler: int,
        current_idx: Int32,
        seqlen_q: Int32,
    ) -> Boolean:
        """
        Check if the current work index is valid for the given query sequence length.

        This method verifies that the current work tile index multiplied by the
        query tiler size is within the bounds of the query sequence length.

        :param q_tiler: Query tiler size.
        :type q_tiler: int
        :param current_idx: Current work index.
        :type current_idx: Int32
        :param seqlen_q: Query sequence length.
        :type seqlen_q: Int32
        :return: True if the work is valid, False otherwise.
        :rtype: Boolean
        """
        return current_idx * q_tiler < seqlen_q

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """
        Get information about the current work tile.

        Determines if the current work is valid and computes the tile coordinates
        based on whether the kernel is persistent or non-persistent.

        :return: WorkTileInfo containing tile coordinates and validity flag.
        :rtype: WorkTileInfo
        """
        # Persistent mode: work is valid while the linear index is in range.
        # Non-persistent mode: each CTA processes exactly one tile.
        is_valid = (
            self._current_work_linear_idx < self._num_blocks
            if self._is_persistent
            else self._is_first_block
        )
        blk_coord = (0, 0, 0)
        if self._is_persistent:
            # Map the linear index back into an (M, B, H) coordinate.
            blk_coord = self._problem_shape_mbh.get_hier_coord(
                self._current_work_linear_idx, loc=loc, ip=ip
            )
        else:
            blk_coord = self._blk_coord
        # cur_tile_coord is (mid, 0, (bid, hid))
        cur_tile_coord = (
            blk_coord[0],
            0,
            (blk_coord[1], blk_coord[2]),
        )
        return WorkTileInfo(cur_tile_coord, is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """
        Get the initial work tile information.

        :return: Initial WorkTileInfo.
        :rtype: WorkTileInfo
        """
        return self.get_current_work(loc=loc, ip=ip)

    def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
        """
        Advance to the next work tile.

        For persistent kernels, advances by the number of persistent SMs.
        For non-persistent kernels, marks that the first block has been processed.

        :param advance_count: Number of steps to advance (default: 1).
        :type advance_count: int
        """
        if self._is_persistent:
            # Each persistent SM strides over the problem space in lockstep.
            self._current_work_linear_idx += advance_count * self.num_persistent_sm
        # Cleared unconditionally; only consulted in non-persistent mode.
        self._is_first_block = False

    def __extract_mlir_values__(self):
        # Value layout: params (3) + linear idx (1) + blk_coord (3) +
        # grid_shape (3) = 10 values; see __new_from_mlir_values__.
        values = extract_mlir_values(self._params)
        values.extend(extract_mlir_values(self._current_work_linear_idx))
        values.extend(extract_mlir_values(self._blk_coord))
        values.extend(extract_mlir_values(self._grid_shape))
        return values

    def __new_from_mlir_values__(self, values):
        # Must mirror the 3+1+3+3 layout produced by __extract_mlir_values__.
        assert len(values) == 10
        new_params = new_from_mlir_values(self._params, values[0:3])
        new_current_work_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, [values[3]]
        )
        new_blk_coord = new_from_mlir_values(self._blk_coord, values[4:7])
        new_grid_shape = new_from_mlir_values(self._grid_shape, values[7:])
        return FmhaStaticTileScheduler(
            new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape
        )
def create_fmha_static_tile_scheduler(
    params: FmhaStaticTileSchedulerParams,
    blk_coord: cute.Coord,
    grid_shape: cute.Shape,
) -> FmhaStaticTileScheduler:
    """Build an FMHA static tile scheduler for one CTA.

    The scheduler's initial linear work index is this block's leading
    coordinate, so in persistent mode each CTA starts at its own offset.

    :param params: Scheduler parameters.
    :type params: FmhaStaticTileSchedulerParams
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param grid_shape: Grid shape.
    :type grid_shape: cute.Shape
    :return: New FmhaStaticTileScheduler instance.
    :rtype: FmhaStaticTileScheduler
    """
    initial_linear_idx = blk_coord[0]
    return FmhaStaticTileScheduler(params, initial_linear_idx, blk_coord, grid_shape)
def create_fmha_static_tile_scheduler_params(
    is_persistent: bool,
    problem_shape_mbh: cute.Shape,
) -> FmhaStaticTileSchedulerParams:
    """Build the parameter bundle consumed by the FMHA static tile scheduler.

    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :param problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape
    :return: New FmhaStaticTileSchedulerParams instance.
    :rtype: FmhaStaticTileSchedulerParams
    """
    return FmhaStaticTileSchedulerParams(
        is_persistent,
        problem_shape_mbh,
    )
def compute_grid(
    o_shape: cute.Shape,
    cta_tiler: Tuple[int, int, int],
    is_persistent: bool,
) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]:
    """Compute grid parameters for an FMHA operation.

    Derives the (M, B, H) tile-count shape from the output tensor shape and
    the CTA tiler, then asks the scheduler for the matching launch grid.

    The output tensor o has shape (s, d, ((h_r, h_k), b)) where:

    - s: sequence length
    - d: head dimension
    - h_r: number of heads for query
    - h_k: number of heads for key
    - b: batch size

    :param o_shape: Output tensor shape for grid computation.
    :type o_shape: cute.Shape
    :param cta_tiler: CTA tiler dimensions (M, N, K).
    :type cta_tiler: Tuple[int, int, int]
    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :return: Tuple of (scheduler_params, grid_shape).
    :rtype: Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]
    """
    # Number of M tiles along the sequence, plus the two grouped extents.
    m_tiles = cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0])
    head_extent = cute.size(o_shape[2][0])
    batch_extent = cute.size(o_shape[2][1])
    tile_sched_params = create_fmha_static_tile_scheduler_params(
        is_persistent,
        (m_tiles, head_extent, batch_extent),
    )
    grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params)
    return tile_sched_params, grid
##############################################################################
# Fused Mask
##############################################################################
class MaskEnum(enum.Enum):
    """Enumeration of mask types for FMHA operations.

    - RESIDUAL_MASK: Residual mask for handling variable sequence lengths
    - RESIDUAL_MASK_BWD: Residual mask for the backward pass
    - WINDOW_MASK: Window mask for attention which also includes causal and no mask
    - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation that the end of q is aligned with the end of k
    - WINDOW_MASK_BWD: Window mask for backward pass
    - WINDOW_MASK_BWD_INFERENCE: Same as the window mask for backward pass, but has the limitation that the end of q is aligned with the end of k
    """

    RESIDUAL_MASK = enum.auto()
    RESIDUAL_MASK_BWD = enum.auto()
    WINDOW_MASK = enum.auto()
    WINDOW_MASK_INFERENCE = enum.auto()
    WINDOW_MASK_BWD = enum.auto()
    WINDOW_MASK_BWD_INFERENCE = enum.auto()
class FusedMask:
    """A fused mask implementation for FMHA operations.

    This class handles different types of attention masks including no mask,
    residual mask for variable sequence lengths, and causal mask for
    autoregressive attention patterns.

    The class provides methods to:

    - Calculate trip counts for different mask types
    - Apply masks to attention scores
    - Handle masked and unmasked trip calculations

    All methods are used as static helpers (``FusedMask.method(...)``); enum
    comparisons inside ``cutlass.const_expr`` consistently use ``is``, which
    is semantically identical to ``==`` for :class:`enum.Enum` members.
    """

    # NOTE(review): unlike its siblings this method is not decorated with
    # @cute.jit; callers invoke it as FusedMask.get_trip_count(...) so the
    # missing `self` is harmless, but confirm the decorator was not dropped
    # by accident.
    def get_trip_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Calculate the number of trips needed for the current block.

        The trip count depends on the mask type and the block coordinates.
        For causal masks, it considers the autoregressive constraint. The
        first block that actually intersects the window (see
        :meth:`get_trip_start`) is subtracted from the raw count.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :return: Number of trips needed.
        :rtype: Int32
        """
        result = 0
        offset = 0
        # For inference masks the end of q is aligned with the end of k, so
        # indices are shifted by the sequence-length difference.
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k
        if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK):
            result = cute.ceil_div(seqlen_k, tile_shape[1])
        if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK_BWD):
            result = cute.ceil_div(seqlen_q, tile_shape[0])
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is None):
                # No right window bound: visit every k tile.
                result = cute.ceil_div(seqlen_k, tile_shape[1])
            else:
                # Last k index any q row of this tile may attend to.
                max_idx_q = (blk_coord[0] + 1) * tile_shape[0]
                idx_k = max_idx_q + offset + window_size_right
                tmp_blocks_k = cute.ceil_div(idx_k, tile_shape[1])
                max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
                result = min(max_blocks_k, tmp_blocks_k)
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            if cutlass.const_expr(window_size_left is None):
                # No left window bound: visit every q tile.
                result = cute.ceil_div(seqlen_q, tile_shape[0])
            else:
                max_idx_k = (blk_coord[1] + 1) * tile_shape[1]
                idx_k = max_idx_k + offset + window_size_left
                tmp_blocks_q = cute.ceil_div(idx_k, tile_shape[0])
                max_blocks_q = cute.ceil_div(seqlen_q, tile_shape[0])
                result = min(max_blocks_q, tmp_blocks_q)
        # Trips before the window starts contribute nothing.
        start_block = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        result = result - start_block
        return result

    @cute.jit
    def get_trip_start(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Get the start of the trip for the current block.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :return: Index of the first tile inside the window (0 if unbounded).
        :rtype: Int32
        """
        result = 0
        offset = 0
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_left is not None):
                # First k index the first q row of this tile may attend to.
                min_idx_q = blk_coord[0] * tile_shape[0]
                idx_k = min_idx_q + offset - window_size_left
                tmp_blocks_k = idx_k // tile_shape[1]
                result = max(tmp_blocks_k, result)
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                min_idx_k = blk_coord[1] * tile_shape[1]
                idx_q = min_idx_k + offset - window_size_right
                tmp_blocks_q = idx_q // tile_shape[0]
                result = max(tmp_blocks_q, result)
        return result

    @cute.jit
    def get_leading_mask_id(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Tuple[Int32, Int32]:
        """
        Get the begin and end tile idx for the leading mask.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :return: Tuple of (begin, end) tile idx for the leading mask.
        :rtype: Tuple[Int32, Int32]
        """
        offset = 0
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k
        leading_mask_begin = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        trip_count = FusedMask.get_trip_count(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        leading_mask_end = leading_mask_begin
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_left is not None):
                min_idx_q = (
                    (blk_coord[0] + 1) * tile_shape[0] + offset - window_size_left
                )
                # Clamp to the last visited tile of this block.
                leading_mask_end = min(
                    cute.ceil_div(min_idx_q, tile_shape[1]) - 1,
                    trip_count + leading_mask_begin - 1,
                )
            else:
                # No left window: empty leading-mask range (end < begin).
                leading_mask_end = leading_mask_begin - 1
        elif cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                min_idx_k = (
                    (blk_coord[1] + 1) * tile_shape[1] + offset - window_size_right
                )
                leading_mask_end = cute.ceil_div(min_idx_k, tile_shape[0]) - 1
            else:
                leading_mask_end = leading_mask_begin - 1
        return leading_mask_begin, leading_mask_end

    @cute.jit
    def get_trailing_mask_id(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Tuple[Optional[Int32], Optional[Int32]]:
        """
        Get the begin and end tile idx for the trailing mask.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :return: Tuple of (begin, end) tile idx for the trailing mask.
        :rtype: Tuple[Int32, Int32]
        """
        offset = 0
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k
        trip_start = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        trip_count = FusedMask.get_trip_count(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        trailing_mask_begin, trailing_mask_end = None, None
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                min_idx_q = blk_coord[0] * tile_shape[0] + offset + window_size_right
                trailing_mask_begin = min(
                    min_idx_q // tile_shape[1], trip_count + trip_start - 1
                )
                trailing_mask_end = trip_count + trip_start - 1
            else:
                # last tile, we always apply mask on it regardless whether it's a residual tile
                trailing_mask_begin = trip_count + trip_start - 1
                trailing_mask_end = trip_count + trip_start - 1
        else:
            if cutlass.const_expr(window_size_left is not None):
                min_idx_k = blk_coord[1] * tile_shape[1] + offset + window_size_left + 1
                max_idx_k = (
                    (blk_coord[1] + 1) * tile_shape[1] + offset + window_size_left
                )
                trailing_mask_begin = min(
                    cute.ceil_div(min_idx_k, tile_shape[0]) - 1,
                    trip_count + trip_start - 1,
                )
                trailing_mask_end = min(
                    cute.ceil_div(max_idx_k, tile_shape[0]) - 1,
                    trip_count + trip_start - 1,
                )
            else:
                # last tile, we always apply mask on it regardless whether it's a residual tile
                trailing_mask_begin = trip_count + trip_start - 1
                trailing_mask_end = trip_count + trip_start - 1
        return trailing_mask_begin, trailing_mask_end

    @cute.jit
    def get_masked_leading_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Calculate the number of masked trips for the leading mask.

        This is used for blocks that need special handling due to masking.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :return: Number of masked trips.
        :rtype: Int32
        """
        result = 0
        # Residual masks have no leading window region.
        if cutlass.const_expr(
            mask_type is not MaskEnum.RESIDUAL_MASK
            and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
        ):
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                result = max(leading_mask_end - leading_mask_begin + 1, 0)
        return result

    @cute.jit
    def get_masked_trailing_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
        rem_count: Int32 = 0,
    ) -> Int32:
        """
        Calculate the number of masked trips for the trailing mask.

        This is used for blocks that need special handling due to masking.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :param rem_count: Remaining count from previous calculations.
        :type rem_count: Int32
        :return: Number of masked trips.
        :rtype: Int32
        """
        result = 0
        if cutlass.const_expr(
            mask_type is not MaskEnum.RESIDUAL_MASK
            and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
        ):
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                trailing_mask_begin, trailing_mask_end = FusedMask.get_trailing_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                if cutlass.const_expr(
                    trailing_mask_begin is not None and trailing_mask_end is not None
                ):
                    # Avoid double-counting tiles already covered by the
                    # leading mask when the two ranges overlap.
                    if trailing_mask_begin <= leading_mask_end:
                        result = max(trailing_mask_end - leading_mask_end, 0)
                    else:
                        result = max(trailing_mask_end - trailing_mask_begin + 1, 0)
            else:
                # Unbounded window: only a residual (partial) last tile
                # needs masking.
                if seqlen_k % tile_shape[1] != 0:
                    result = 1
                else:
                    result = 0
        return result + rem_count

    @cute.jit
    def get_unmasked_trip_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Calculate the number of unmasked trips for the current block.

        This represents the number of trips that don't require special
        masking treatment: total trips minus the leading- and
        trailing-masked trips.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :return: Number of unmasked trips.
        :rtype: Int32
        """
        result = (
            FusedMask.get_trip_count(
                mask_type,
                blk_coord,
                tile_shape,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
            )
            - FusedMask.get_masked_leading_count(
                mask_type,
                blk_coord,
                tile_shape,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
            )
            - FusedMask.get_masked_trailing_count(
                mask_type,
                blk_coord,
                tile_shape,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                0,
            )
        )
        return result

    @cute.jit
    def apply_mask(
        mask_type: MaskEnum,
        acc_qk: cute.Tensor,
        index_qk: cute.Tensor,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[int] = None,
        window_size_right: Optional[int] = None,
        index_transform: cutlass.Constexpr = lambda index_q, index_k: (
            index_q,
            index_k,
        ),
    ):
        """
        Apply the appropriate mask to the attention scores.

        This method modifies the attention scores (acc_qk) in place based on
        the mask type and the positions in the index tensor, writing -inf to
        masked-out entries.

        :param mask_type: Type of mask to use
        :type mask_type: MaskEnum
        :param acc_qk: Accumulated QK attention scores tensor.
        :type acc_qk: cute.Tensor
        :param index_qk: Index tensor containing position information.
        :type index_qk: cute.Tensor
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[int]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[int]
        """
        # BUGFIX: was `tidx, tidy, tidx = ...`, which bound the z index to
        # `tidx` and never bound `tidz` (the values are unused below).
        tidx, tidy, tidz = cute.arch.thread_idx()
        # NOTE(review): unlike the trip-count helpers, the BWD_INFERENCE case
        # here also uses seqlen_k - seqlen_q (not seqlen_q - seqlen_k). This
        # may be intentional if index_transform swaps q/k for the backward
        # pass -- confirm against callers.
        offset = (
            seqlen_k - seqlen_q
            if cutlass.const_expr(
                mask_type is MaskEnum.WINDOW_MASK_INFERENCE
                or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
            )
            else 0
        )
        for i in cutlass.range_constexpr(cute.size(acc_qk)):
            index_q, index_k = index_transform(*index_qk[i])
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                if cutlass.const_expr(window_size_left is None):
                    # Right-bounded window only (e.g. causal).
                    if index_q + offset + window_size_right < index_k:
                        acc_qk[i] = -Float32.inf
                    if index_k >= seqlen_k or index_q >= seqlen_q:  # residual mask
                        acc_qk[i] = -Float32.inf
                elif cutlass.const_expr(window_size_right is None):
                    # Left-bounded window only.
                    if index_q + offset - window_size_left > index_k:
                        acc_qk[i] = -Float32.inf
                    if index_k >= seqlen_k or index_q >= seqlen_q:  # residual mask
                        acc_qk[i] = -Float32.inf
                else:
                    # Two-sided window.
                    max_K_index = min(index_q + offset + window_size_right, seqlen_k)
                    min_K_index = max(0, index_q + offset - window_size_left)
                    if index_k > max_K_index or index_k < min_K_index:
                        acc_qk[i] = -Float32.inf
                    if index_k >= seqlen_k or index_q >= seqlen_q:  # residual mask
                        acc_qk[i] = -Float32.inf
            if cutlass.const_expr(
                mask_type is MaskEnum.RESIDUAL_MASK
                or mask_type is MaskEnum.RESIDUAL_MASK_BWD
            ):
                if index_k >= seqlen_k or index_q >= seqlen_q:
                    acc_qk[i] = -Float32.inf

View File

@ -0,0 +1,457 @@
import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch
@cute.jit
def print_tensor_dlpack(src: cute.Tensor):
    """Print the tensor object itself, then its contents via cute.print_tensor."""
    print(src)
    cute.print_tensor(src)
# Sparse emulation
class SparseEmulation:
    """Emulates 2:4 structured-sparse matrix multiplication on dense tensors.

    Accumulates ``d[m, n] += sum_k a[m, k'] * b[n, k]`` where ``a`` holds the
    compressed (2-of-4) A operand and ``e`` holds 4-bit metadata nibbles
    selecting, for every group of 4 logical K elements, which 2 are stored.
    """

    def __init__(self, M: int, N: int, K: int, L: int):
        """Record the problem extents.

        :param M: Rows of A/D.
        :param N: Columns of D (rows of B as indexed below).
        :param K: Logical (uncompressed) reduction extent.
        :param L: Batch extent (not used by the kernel below).
        """
        self.M = M
        self.N = N
        self.K = K
        self.L = L

    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """Sparse emulation"""
        num_threads = 128
        # One thread per row of D, 128 threads per block.
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.kernel(a, b, d, e).launch(grid=grid, block=block)
        return

    @cute.kernel
    def kernel(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """CUDA kernel to emulate sparse tensor core"""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        # BUGFIX: the row stride between blocks is the block size (128
        # threads, matching the launch in __call__), not M; with the previous
        # `bidx * self.M` every block after the first landed past the
        # `row_idx < M` guard, so rows [128, M) were never processed.
        num_threads = 128
        row_idx = tidx + bidx * num_threads
        # Each 32-bit metadata word covers 8 nibbles * 4 elements = 32 K values.
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # each thread process 1 row
            for col in range(self.N):
                # each meta_idx stands for 32 elements
                for e_idx in range(meta_idx):
                    meta_val = e[(row_idx, e_idx)]
                    for k in range(8):
                        # each k stands for 4 elements
                        meta_row = (meta_val >> (k * 4)) & 0xF
                        # The two 2-bit fields give the positions (0-3) of the
                        # stored elements within the group of 4.
                        idx0 = meta_row & 0x3
                        idx1 = (meta_row >> 2) & 0x3
                        # calculate the idx in b tensor which has value in A tensor
                        km = e_idx * 16 + k * 2
                        km_1 = km + 1
                        kn = e_idx * 32 + k * 4 + idx0
                        kn_1 = e_idx * 32 + k * 4 + idx1
                        d[row_idx, col] += a[row_idx, km] * b[col, kn]
                        d[row_idx, col] += a[row_idx, km_1] * b[col, kn_1]
        return
# Compressor
# compress a sparse tensor to a dense tensor && generate metadata
class Compressor:
    def __init__(self, M: int, K: int, L: int):
        """Create a compressor for an M x K (x L batches) sparse tensor."""
        self.M = M
        self.K = K
        self.L = L
        # Maps a 4-bit metadata value to the positions (within each group of
        # 4 K-elements) of the two elements that are kept, e.g. 0x4 keeps
        # elements 0 and 1. See __compress_on_cpu's docstring for the full
        # encoding table.
        self.pos_map = {
            0x4: [0, 1],
            0x8: [0, 2],
            0xC: [0, 3],
            0x9: [1, 2],
            0xD: [1, 3],
            0xE: [2, 3],
        }
    @cute.jit
    def _init__(self, a: cute.Tensor):
        # NOTE(review): the single-leading-underscore name looks like a typo
        # (perhaps meant as an alternate initializer taking extents from a
        # tensor); it re-invokes __init__ with a's shape. Confirm the intended
        # name and whether anything actually calls it.
        self.__init__(a.shape[0], a.shape[1], a.shape[2])
def compress(self, a, a_compressed, meta, run_on_cpu: bool):
if run_on_cpu:
if a.device.type != "cpu":
raise ValueError("a must be on cpu")
return self.__compress_on_cpu(a, a_compressed, meta)
else:
if a.device.type != "cuda":
raise ValueError("a must be on cuda")
return self.__compress_on_cuda(a, a_compressed, meta)
def __compress_on_cpu(self, a, a_compressed, meta):
"""
compress the tensor on cpu
# Convert to 4-bit metadata value
# The metadata value represents which 2 elements are non-zero
# 0x4: [1,1,0,0] - first two elements are non-zero
# 0x8: [1,0,1,0] - first and third elements are non-zero
# 0xC: [1,0,0,1] - first and fourth elements are non-zero
# 0x9: [0,1,1,0] - second and third elements are non-zero
# 0xD: [0,1,0,1] - second and fourth elements are non-zero
# 0xE: [0,0,1,1] - third and fourth elements are non-zero
# special case:
# [0,0,0,0] == [0,0,1,1]
# [1,0,0,0] == [1,0,0,1]
# [0,1,0,0] == [0,1,0,1]
# [0,0,1,0] == [0,0,1,1]
# [0,0,0,1] == [0,0,1,1]
"""
M, K = a.shape
assert a_compressed.shape == (
M,
K // 2,
), f"Expected a_compressed shape {(M, K // 2)}, got {a_compressed.shape}"
assert meta.shape == (
M,
K // 4 // 8,
), f"Expected meta shape {(M, K // 4 // 8)}, got {meta.shape}"
for m in range(M):
k_meta = 0
for k in range(0, K, 4):
chunk = a[m, k : k + 4]
non_zero_indices = torch.nonzero(chunk).squeeze()
meta_val = 0xE
if torch.equal(non_zero_indices, torch.tensor([0, 1])):
meta_val = 0x4
elif torch.equal(non_zero_indices, torch.tensor([0, 2])):
meta_val = 0x8
elif torch.equal(non_zero_indices, torch.tensor([0, 3])) or torch.equal(
non_zero_indices, torch.tensor(0)
):
meta_val = 0xC
elif torch.equal(non_zero_indices, torch.tensor([1, 2])):
meta_val = 0x9
elif torch.equal(non_zero_indices, torch.tensor([1, 3])) or torch.equal(
non_zero_indices, torch.tensor(1)
):
meta_val = 0xD
elif torch.equal(non_zero_indices, torch.tensor([2, 3])) or torch.equal(
non_zero_indices, torch.tensor(2)
):
meta_val = 0xE
elif torch.equal(non_zero_indices, torch.tensor([])) or torch.equal(
non_zero_indices, torch.tensor(3)
):
meta_val = 0xE
else:
raise ValueError(f"Invalid non-zero pattern: {non_zero_indices}")
meta_idx = k // 4 // 8
meta_bit_pos = (k // 4) % 8
if k_meta == meta_idx:
k_meta = meta_idx + 1
meta[m, meta_idx] = 0
meta[m, meta_idx] |= meta_val << (meta_bit_pos * 4)
compressed_idx = k // 2
index = self.pos_map[meta_val]
a_compressed[m, compressed_idx] = chunk[index[0]]
a_compressed[m, compressed_idx + 1] = chunk[index[1]]
def __compress_on_cuda(self, a, a_compressed, meta):
"""
compress the tensor on cuda
"""
a_tensor = from_dlpack(a)
a_compressed_tensor = from_dlpack(a_compressed)
meta_tensor = from_dlpack(meta)
self.compress_on_cuda_impl(a_tensor, a_compressed_tensor, meta_tensor)
return
@cute.jit
def compress_on_cuda_impl(
self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
):
"""Compress the input tensor using the metadata"""
num_threads = 128
grid = (cute.ceil_div(self.M, num_threads), 1, 1)
block = (num_threads, 1, 1)
self.compressor_impl(a, a_compressed, meta).launch(grid=grid, block=block)
@cute.kernel
def compressor_impl(
self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
):
"""CUDA kernel to compress the tensor"""
tidx, tidy, tidz = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
m = a.shape[0]
k = a.shape[1]
# each thread process 1 row
row_idx = tidx + bidx * self.M
meta_idx = self.K // 4 // 8
if row_idx < self.M:
# each meta_idx stands for 32 elements
for i in range(meta_idx):
meta[row_idx, i] = 0
# each k stands for 4 elements
for j in range(8):
val = a[row_idx, i * 32 + j * 4]
val_1 = a[row_idx, i * 32 + j * 4 + 1]
val_2 = a[row_idx, i * 32 + j * 4 + 2]
val_3 = a[row_idx, i * 32 + j * 4 + 3]
value_idx = 0
value_idx_1 = 0
value_idx_2 = 0
value_idx_3 = 0
pos0 = 0
pos1 = 0
if val != 0:
value_idx = 1
pos0 = 0
if val_1 != 0:
value_idx_1 = 1
if val_2 != 0:
value_idx_2 = 1
if val_3 != 0:
value_idx_3 = 1
pos = [value_idx, value_idx_1, value_idx_2, value_idx_3]
tmp = 0
if pos == [0, 0, 0, 0]:
tmp = 0xE
pos0 = 2
pos1 = 3
elif pos == [1, 0, 0, 0]:
tmp = 0xC
pos0 = 0
pos1 = 3
elif pos == [0, 1, 0, 0]:
tmp = 0xD
pos0 = 1
pos1 = 3
elif pos == [0, 0, 1, 0]:
tmp = 0xE
pos0 = 2
pos1 = 3
elif pos == [0, 0, 0, 1]:
tmp = 0xE
pos0 = 2
pos1 = 3
elif pos == [1, 1, 0, 0]:
tmp = 0x4
pos0 = 0
pos1 = 1
elif pos == [1, 0, 1, 0]:
tmp = 0x8
pos0 = 0
pos1 = 2
elif pos == [1, 0, 0, 1]:
tmp = 0xC
pos0 = 0
pos1 = 3
elif pos == [0, 1, 1, 0]:
tmp = 0x9
pos0 = 1
pos1 = 2
elif pos == [0, 1, 0, 1]:
tmp = 0xD
pos0 = 1
pos1 = 3
elif pos == [0, 0, 1, 1]:
tmp = 0xE
pos0 = 2
pos1 = 3
# cute.printf(row_idx, cutlass.Float32(val), cutlass.Float32(val_1), cutlass.Float32(val_2), cutlass.Float32(val_3), tmp)
meta[row_idx, i] |= tmp << (j * 4)
a_compressed[row_idx, i * 16 + j * 2] = a[
row_idx, i * 32 + j * 4 + pos0
]
a_compressed[row_idx, i * 16 + j * 2 + 1] = a[
row_idx, i * 32 + j * 4 + pos1
]
return
# SparseUtils is used to generate sparse tensor
# format torch.Tensor
class SparseUtils:
    """Generate 4:2 structured-sparse test tensors (as ``torch.Tensor``).

    In every 4-element chunk along K exactly two elements are zeroed, either
    at random positions (CPU path without explicit metadata) or according to
    pre-generated per-chunk 4-bit codes.
    """

    def __init__(self, M: int, K: int, L: int, dtype):
        """
        :param M: number of rows.
        :param K: reduction extent (packing assumes a multiple of 32).
        :param L: batch extent (kept for API symmetry; single batch used).
        :param dtype: element type as a ``cutlass`` DataType.
        """
        self.M = M
        self.K = K
        self.L = L
        self.dtype = dtype
        # One random 4-bit code per 4-element chunk.
        self.meta_data = self._generate_meta_data_4_2()
        self._use_specific_meta_data = False

    def _get_type(self):
        """Map the cutlass DataType to the matching torch dtype."""
        if self.dtype == cutlass.Float16:
            return torch.float16
        elif self.dtype == cutlass.Float32:
            return torch.float32
        elif self.dtype == cutlass.Int8:
            return torch.int8
        else:
            raise ValueError(f"Unsupported dtype: {self.dtype}")

    def _generate_meta_data_4_2(self):
        """Draw a random valid 4-bit code for every 4-element chunk.

        Valid 4:2 codes and the occupancy they describe (1 = kept):
          0x4: [1,1,0,0]   0x8: [1,0,1,0]   0xC: [1,0,0,1]
          0x9: [0,1,1,0]   0xD: [0,1,0,1]   0xE: [0,0,1,1]
        """
        meta_value = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE]
        # 4:2 sparsity: each 4-element chunk maps to one 4-bit code.
        K_NumChunk = self.K // 4
        meta_data = np.random.choice(
            meta_value, size=(self.M, K_NumChunk), replace=True
        )
        meta_data = torch.from_numpy(
            np.array(meta_data).astype(np.uint8).reshape(self.M, K_NumChunk)
        )
        return meta_data

    def _pack_meta_data(self):
        """Pack eight 4-bit chunk codes into one uint32 word per group."""
        tmp = []
        K_NumChunk = self.K // 4
        for i in range(self.M):
            for j in range(K_NumChunk // 8):
                v = 0
                for k in range(8):
                    vv = int(self.meta_data[i, j * 8 + k] & 0xF)
                    v |= vv << (k * 4)
                tmp.append(v)
        result = torch.from_numpy(
            np.array(tmp).astype(np.uint32).reshape(self.M, K_NumChunk // 8)
        )
        return result

    def use_specific_meta_data(self, meta_data: torch.Tensor = None):
        """Make the CPU path zero positions from explicit per-chunk metadata.

        :param meta_data: optional (M, K//4) tensor of 4-bit codes; when
            omitted, the random metadata generated in ``__init__`` is used.
        """
        if meta_data is not None:
            self.meta_data = meta_data
        self._use_specific_meta_data = True

    def generate_sparse_4_2_tensor_with_tensor(self, a, run_on_cpu):
        """Zero two elements of every 4-chunk of ``a`` in place and return it.

        :param a: (M, K) dense input tensor on the device matching the path.
        :param run_on_cpu: select the CPU reference path or the CUDA kernel.
        :raises ValueError: if ``a`` lives on the wrong device for the path.
        """
        if run_on_cpu:
            if a.device.type != "cpu":
                raise ValueError("a must be on cpu")
            return self.__generate_sparse_tensor_cpu(a)
        else:
            if a.device.type != "cuda":
                raise ValueError("a must be on cuda")
            # The CUDA path always masks from the pre-generated metadata.
            a_tensor = from_dlpack(a)
            packed_meta_data = self._pack_meta_data()
            meta_tensor = from_dlpack(packed_meta_data.cuda())
            self.__generate_sparse_tensor_cuda(a_tensor, meta_tensor)
            return a

    def generate_4_2_sparse_tensor(self, run_on_cpu):
        """Create a fresh random (M, K) tensor and sparsify it."""
        dtype = self._get_type()
        a = torch.empty(self.M, self.K).random_(-5, 5).to(dtype)
        if run_on_cpu:
            return self.generate_sparse_4_2_tensor_with_tensor(a, run_on_cpu)
        else:
            return self.generate_sparse_4_2_tensor_with_tensor(a.cuda(), run_on_cpu)

    def __generate_sparse_tensor_cpu(self, a):
        """Sparsify on CPU: random zero positions, or a mask from metadata."""
        if not self._use_specific_meta_data:
            for m in range(self.M):
                for k in range(0, self.K, 4):
                    # Randomly choose 2 positions of each chunk to zero out.
                    zero_indices = torch.randperm(4)[:2]
                    a[m, k + zero_indices[0]] = 0
                    a[m, k + zero_indices[1]] = 0
            return a
        else:
            # Build a flat 0/1 keep-mask from the per-chunk metadata codes.
            tensor_mask = []
            for i in range(self.M):
                for j in range(self.K // 4):
                    meta_val = self.meta_data[i, j]
                    tmp = []
                    if meta_val == 0x4:
                        tmp = [1, 1, 0, 0]
                    elif meta_val == 0x8:
                        tmp = [1, 0, 1, 0]
                    elif meta_val == 0xC:
                        tmp = [1, 0, 0, 1]
                    elif meta_val == 0x9:
                        tmp = [0, 1, 1, 0]
                    elif meta_val == 0xD:
                        tmp = [0, 1, 0, 1]
                    elif meta_val == 0xE:
                        tmp = [0, 0, 1, 1]
                    tensor_mask.extend(tmp)
            a = torch.reshape(a, (-1,))
            mask = torch.tensor(tensor_mask)
            a = a * mask
            a = torch.reshape(a, (self.M, self.K))
            return a

    @cute.jit
    def __generate_sparse_tensor_cuda(self, a: cute.Tensor, meta: cute.Tensor):
        """Launch the masking kernel with one thread per row."""
        assert a.shape[0] == self.M and a.shape[1] == self.K
        assert meta.shape[0] == self.M and meta.shape[1] == self.K // 4 // 8
        num_threads = 128
        grid = (cute.ceil_div(self.M, num_threads), 1, 1)
        block = (num_threads, 1, 1)
        self.kernel(a, meta).launch(grid=grid, block=block)

    @cute.kernel
    def kernel(self, a: cute.Tensor, meta: cute.Tensor):
        """Apply sparsity mask to input tensor using metadata."""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        bdim, _, _ = cute.arch.block_dim()
        # One thread per row. The original used ``bidx * self.M``, which is
        # only correct when M equals the launch block size; scale by the
        # block dimension instead.
        row_idx = tidx + bidx * bdim
        num_meta_words = self.K // 4 // 8
        if row_idx < self.M:
            # Iterate over each 32-element group (one uint32 metadata word).
            for i in range(num_meta_words):
                meta_val = meta[(row_idx, i)]
                # Iterate over each 4-element chunk within the group.
                for j in range(8):
                    meta_row = (meta_val >> (j * 4)) & 0xF
                    idx0 = meta_row & 0x3
                    idx1 = (meta_row >> 2) & 0x3
                    r_id0 = 0
                    r_id1 = 0
                    # r_id* are the two positions to zero (complement of the
                    # kept positions idx0/idx1).
                    if idx0 >= 2 and idx1 >= 2:
                        r_id0 = 0
                        r_id1 = 1
                    elif idx0 <= 1 and idx1 <= 1:
                        r_id0 = 2
                        r_id1 = 3
                    else:
                        # One kept index in each pair: zero each index's
                        # partner within its pair.
                        r_id0 = idx0 ^ 0b1
                        r_id1 = idx1 ^ 0b1
                    row_id0 = r_id0 + i * 32 + j * 4
                    row_id1 = r_id1 + i * 32 + j * 4
                    a[row_idx, row_id0] = self.dtype(0.0)
                    a[row_idx, row_id1] = self.dtype(0.0)
        return

# ---------------------------------------------------------------------------
# Test module for the sparse utilities above (originally a separate file that
# imports them as ``sparse_utils``).
# ---------------------------------------------------------------------------
import sparse_utils as su
import cutlass
import torch
from cutlass.cute.runtime import from_dlpack
import numpy as np
import pytest
@pytest.mark.L0
def test_sparse_cpu():
    """CPU path: sparsify on CPU, compress on CPU, verify with an emulated
    sparse GEMM on the GPU against a dense einsum reference."""
    M = 128
    N = 32
    K = 32
    L = 1
    debug = False
    # Generate a 4:2 sparse (M, K) fp16 tensor on the CPU.
    a = torch.empty(M, K).random_(-5, 5).to(torch.float16)
    sparse_utils = su.SparseUtils(M, K, L, cutlass.Float16)
    if debug:
        sparse_utils.use_specific_meta_data()
    a_gen_from_cpu = sparse_utils.generate_sparse_4_2_tensor_with_tensor(a, True)
    # Compress to (M, K//2) values plus packed uint32 metadata.
    a_compressed_cpu = torch.empty(M, K // 2).to(torch.float16)
    meta_data_cpu = torch.empty(M, K // 4 // 8).to(torch.uint32)
    compressor = su.Compressor(M, K, L)
    compressor.compress(a_gen_from_cpu, a_compressed_cpu, meta_data_cpu, True)
    # Validate with the emulated sparse GEMM on the GPU.
    b = torch.empty(N, K).random_(-5, 5).to(torch.float16).cuda()
    d = torch.empty(M, N).zero_().to(torch.float16).cuda()
    b_tensor = from_dlpack(b)
    d_tensor = from_dlpack(d)
    a_compressed_cpu_tensor = from_dlpack(a_compressed_cpu.cuda())
    meta_data_cpu_tensor = from_dlpack(meta_data_cpu.cuda())
    # Pass L rather than the hard-coded literal 1 for consistency.
    sparse_emulation = su.SparseEmulation(M, N, K, L)
    sparse_emulation(a_compressed_cpu_tensor, b_tensor, d_tensor, meta_data_cpu_tensor)
    # Dense reference: A (M,K) x B^T (K,N).
    ref = torch.einsum("mk,nk->mn", a_gen_from_cpu.cpu(), b.cpu())
    if debug:
        # Dump every intermediate for offline inspection.
        np.savetxt("a.txt", a_gen_from_cpu.cpu().numpy(), fmt="%f")
        np.savetxt("a_compressed_cpu.txt", a_compressed_cpu.cpu().numpy(), fmt="%f")
        np.savetxt("meta_data_cpu.txt", meta_data_cpu.cpu().numpy(), fmt="%f")
        np.savetxt("d.txt", d.cpu().numpy(), fmt="%f")
        np.savetxt("ref.txt", ref.cpu().numpy(), fmt="%f")
    torch.testing.assert_close(d.cpu(), ref)
    print("cpu d == ref")
@pytest.mark.L0
def test_sparse_cuda():
    """CUDA path: sparsify on GPU, compress on GPU, verify with an emulated
    sparse GEMM against a dense einsum reference."""
    M = 128
    N = 32
    K = 32
    L = 1
    debug = False
    sparse_utils = su.SparseUtils(M, K, L, cutlass.Float16)
    if debug:
        sparse_utils.use_specific_meta_data()
    # Generate a 4:2 sparse (M, K) fp16 tensor directly on the GPU.
    # (The original also allocated an unused random ``a`` here; removed.)
    a_gen_from_cuda = sparse_utils.generate_4_2_sparse_tensor(False)
    # Compress to (M, K//2) values plus packed uint32 metadata on the GPU.
    a_compressed_cuda = torch.empty(M, K // 2).to(torch.float16).cuda()
    meta_data_cuda = torch.empty(M, K // 4 // 8).to(torch.uint32).cuda()
    compressor = su.Compressor(M, K, L)
    compressor.compress(a_gen_from_cuda, a_compressed_cuda, meta_data_cuda, False)
    # Validate with the emulated sparse GEMM.
    b = torch.empty(N, K).random_(-5, 5).to(torch.float16).cuda()
    d = torch.empty(M, N).zero_().to(torch.float16).cuda()
    b_tensor = from_dlpack(b)
    d_tensor = from_dlpack(d)
    a_compressed_cuda_tensor = from_dlpack(a_compressed_cuda)
    meta_data_cuda_tensor = from_dlpack(meta_data_cuda)
    # Pass L rather than the hard-coded literal 1 for consistency.
    sparse_emulation = su.SparseEmulation(M, N, K, L)
    sparse_emulation(
        a_compressed_cuda_tensor, b_tensor, d_tensor, meta_data_cuda_tensor
    )
    # Dense reference: A (M,K) x B^T (K,N).
    ref = torch.einsum("mk,nk->mn", a_gen_from_cuda.cpu(), b.cpu())
    if debug:
        # Dump every intermediate for offline inspection.
        np.savetxt("a.txt", a_gen_from_cuda.cpu().numpy(), fmt="%f")
        np.savetxt("a_compressed_cuda.txt", a_compressed_cuda.cpu().numpy(), fmt="%f")
        np.savetxt("meta_data_cuda.txt", meta_data_cuda.cpu().numpy(), fmt="%f")
        np.savetxt("d.txt", d.cpu().numpy(), fmt="%f")
        np.savetxt("ref.txt", ref.cpu().numpy(), fmt="%f")
    torch.testing.assert_close(d.cpu(), ref)
    print("cuda d == ref")
if __name__ == "__main__":
    # Script entry: set up a CUDA context, then run both test paths directly
    # (without pytest).
    cutlass.cuda.initialize_cuda_context()
    test_sparse_cpu()
    test_sparse_cuda()