v4.3 update. (#2709)
* v4.3 update. * Update the cute_dsl_api changelog's doc link * Update version to 4.3.0 * Update the example link * Update doc to encourage user to install DSL from requirements.txt --------- Co-authored-by: Larry Wu <larwu@nvidia.com>
This commit is contained in:
0
examples/python/CuTeDSL/utils/__init__.py
Normal file
0
examples/python/CuTeDSL/utils/__init__.py
Normal file
975
examples/python/CuTeDSL/utils/fmha_helpers.py
Normal file
975
examples/python/CuTeDSL/utils/fmha_helpers.py
Normal file
@ -0,0 +1,975 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from typing import Tuple, Optional
|
||||
import cutlass
|
||||
from cutlass.cute.typing import Boolean
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
Int32,
|
||||
Float32,
|
||||
min,
|
||||
extract_mlir_values,
|
||||
new_from_mlir_values,
|
||||
)
|
||||
from cutlass.utils.hardware_info import HardwareInfo
|
||||
from cutlass.utils import WorkTileInfo
|
||||
import cutlass.cute as cute
|
||||
|
||||
##############################################################################
|
||||
# Fmha static tile scheduler
|
||||
##############################################################################
|
||||
|
||||
|
||||
class FmhaStaticTileSchedulerParams:
    """A class to represent parameters for the FMHA (Fused Multi-Head Attention) static tile scheduler.

    This class holds the configuration parameters needed to initialize and configure
    the tile scheduler for FMHA operations.

    :ivar is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :ivar problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape
    """

    def __init__(
        self,
        is_persistent: bool,
        problem_shape_mbh: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """
        Initializes the FmhaStaticTileSchedulerParams with the given parameters.

        :param is_persistent: Whether to use persistent kernel mode.
        :type is_persistent: bool
        :param problem_shape_mbh: Problem shape in (M, B, H) format.
        :type problem_shape_mbh: cute.Shape
        """
        self.is_persistent = is_persistent
        self.problem_shape_mbh = problem_shape_mbh
        # MLIR source location / insertion point used when this object is
        # reconstructed during DSL tracing.
        self._loc = loc
        self._ip = ip

    def __extract_mlir_values__(self):
        """Flatten the dynamic fields into a flat list of MLIR values.

        Records how many values each field contributed (in ``_values_pos``)
        so that ``__new_from_mlir_values__`` can regroup them later.
        """
        values, self._values_pos = [], []
        for obj in [self.problem_shape_mbh]:
            obj_values = extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a params object from a flat list of MLIR values.

        Each dynamic field consumes the slice length recorded by
        ``__extract_mlir_values__``.
        """
        obj_list = []
        for obj, n_items in zip([self.problem_shape_mbh], self._values_pos):
            obj_list.append(new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        # Fix: forward `ip` alongside `loc` — the original dropped the
        # insertion point on reconstruction, losing half of the tracing
        # context that __init__ deliberately stores.
        return FmhaStaticTileSchedulerParams(
            self.is_persistent, *(tuple(obj_list)), loc=self._loc, ip=self._ip
        )
|
||||
|
||||
|
||||
class FmhaStaticTileScheduler:
    """A static tile scheduler for FMHA (Fused Multi-Head Attention) operations.

    This class manages the scheduling of work tiles for FMHA kernels, supporting
    both persistent and non-persistent kernel modes. It tracks the current work
    position and advances through the problem space efficiently.

    :ivar _params: Scheduler parameters.
    :type _params: FmhaStaticTileSchedulerParams
    :ivar _blk_coord: Block coordinates.
    :type _blk_coord: cute.Coord
    :ivar _grid_shape: Grid shape for the kernel.
    :type _grid_shape: cute.Shape
    :ivar _is_persistent: Whether to use persistent kernel mode.
    :type _is_persistent: bool
    :ivar _current_work_linear_idx: Current linear work index.
    :type _current_work_linear_idx: Int32
    :ivar _problem_shape_mbh: Problem shape in (M, B, H) format.
    :type _problem_shape_mbh: cute.Layout
    :ivar _num_blocks: Number of blocks in the problem.
    :type _num_blocks: Int32
    :ivar _is_first_block: Whether this is the first block.
    :type _is_first_block: bool
    :ivar num_persistent_sm: Number of persistent SMs.
    :type num_persistent_sm: Int32
    """

    def __init__(
        self,
        params: FmhaStaticTileSchedulerParams,
        current_work_linear_idx: Int32,
        blk_coord: cute.Coord,
        grid_shape: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """
        Initializes the FmhaStaticTileScheduler with the given parameters.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams
        :param current_work_linear_idx: Current linear work index.
        :type current_work_linear_idx: Int32
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param grid_shape: Grid shape for the kernel.
        :type grid_shape: cute.Shape
        """
        self._params = params
        self._blk_coord = blk_coord
        self._grid_shape = grid_shape
        self._is_persistent = params.is_persistent
        self._current_work_linear_idx = current_work_linear_idx
        # Build a layout over (M, B, H) so a linear work index can be mapped
        # back to a hierarchical (m, b, h) coordinate (persistent mode).
        self._problem_shape_mbh = cute.make_layout(
            params.problem_shape_mbh, loc=loc, ip=ip
        )
        # Total number of work tiles in the problem.
        self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip)
        # Non-persistent mode processes exactly one tile per CTA; this flag
        # makes only the first work query report valid work.
        self._is_first_block = True
        # In persistent mode the grid size is the number of resident CTAs,
        # which is also the stride between successive work items per CTA.
        self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip)
        self._loc = loc
        self._ip = ip

    # called by host
    @staticmethod
    def get_grid_shape(
        params: FmhaStaticTileSchedulerParams,
        *,
        loc=None,
        ip=None,
    ) -> cute.Shape:
        """
        Determine the grid shape for the FMHA kernel.

        For persistent kernels, the grid shape is limited by the number of SMs
        (Streaming Multiprocessors) available on the device. For non-persistent
        kernels, the grid shape matches the problem shape.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams

        :return: Grid shape as (M, B, H) tuple.
        :rtype: cute.Shape
        """
        if params.is_persistent:
            hardware_info = HardwareInfo()
            sm_count = hardware_info.get_device_multiprocessor_count()
            # Cap the grid at the SM count, but never launch more CTAs than
            # there are work tiles.
            return (
                min(sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip)),
                1,
                1,
            )
        else:
            return params.problem_shape_mbh

    @staticmethod
    def check_valid_work_for_seqlen_q(
        q_tiler: int,
        current_idx: Int32,
        seqlen_q: Int32,
    ) -> Boolean:
        """
        Check if the current work index is valid for the given query sequence length.

        This method verifies that the current work tile index multiplied by the
        query tiler size is within the bounds of the query sequence length.

        :param q_tiler: Query tiler size.
        :type q_tiler: int
        :param current_idx: Current work index.
        :type current_idx: Int32
        :param seqlen_q: Query sequence length.
        :type seqlen_q: Int32

        :return: True if the work is valid, False otherwise.
        :rtype: Boolean
        """
        return current_idx * q_tiler < seqlen_q

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """
        Get information about the current work tile.

        Determines if the current work is valid and computes the tile coordinates
        based on whether the kernel is persistent or non-persistent.

        :return: WorkTileInfo containing tile coordinates and validity flag.
        :rtype: WorkTileInfo
        """
        # Persistent: valid while the linear index is within the problem.
        # Non-persistent: valid only on the first (single) query.
        is_valid = (
            self._current_work_linear_idx < self._num_blocks
            if self._is_persistent
            else self._is_first_block
        )

        blk_coord = (0, 0, 0)
        if self._is_persistent:
            # Map the linear index back to an (m, b, h) coordinate.
            blk_coord = self._problem_shape_mbh.get_hier_coord(
                self._current_work_linear_idx, loc=loc, ip=ip
            )
        else:
            blk_coord = self._blk_coord

        # cur_tile_coord is (mid, 0, (bid, hid))
        cur_tile_coord = (
            blk_coord[0],
            0,
            (blk_coord[1], blk_coord[2]),
        )

        return WorkTileInfo(cur_tile_coord, is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """
        Get the initial work tile information.

        :return: Initial WorkTileInfo.
        :rtype: WorkTileInfo
        """
        return self.get_current_work(loc=loc, ip=ip)

    def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
        """
        Advance to the next work tile.

        For persistent kernels, advances by the number of persistent SMs.
        For non-persistent kernels, marks that the first block has been processed.

        :param advance_count: Number of steps to advance (default: 1).
        :type advance_count: int
        """
        if self._is_persistent:
            # Grid-sized stride: each resident CTA steps over the CTAs
            # running concurrently with it.
            self._current_work_linear_idx += advance_count * self.num_persistent_sm
        self._is_first_block = False

    def __extract_mlir_values__(self):
        # Flatten the dynamic state in a fixed order; __new_from_mlir_values__
        # must consume the values in this same order.
        values = extract_mlir_values(self._params)
        values.extend(extract_mlir_values(self._current_work_linear_idx))
        values.extend(extract_mlir_values(self._blk_coord))
        values.extend(extract_mlir_values(self._grid_shape))
        return values

    def __new_from_mlir_values__(self, values):
        # NOTE(review): assumes exactly 10 flattened values — 3 for the
        # (M, B, H) problem shape, 1 for the linear index, 3 for blk_coord and
        # 3 for grid_shape; i.e. rank-3 shapes/coords. Confirm against callers.
        assert len(values) == 10
        new_params = new_from_mlir_values(self._params, values[0:3])
        new_current_work_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, [values[3]]
        )
        new_blk_coord = new_from_mlir_values(self._blk_coord, values[4:7])
        new_grid_shape = new_from_mlir_values(self._grid_shape, values[7:])
        # NOTE(review): loc/ip are not forwarded here — confirm intentional.
        return FmhaStaticTileScheduler(
            new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape
        )
|
||||
|
||||
|
||||
def create_fmha_static_tile_scheduler(
    params: FmhaStaticTileSchedulerParams,
    blk_coord: cute.Coord,
    grid_shape: cute.Shape,
) -> FmhaStaticTileScheduler:
    """
    Build an FmhaStaticTileScheduler from scheduler parameters.

    The scheduler's starting linear work index is the first block coordinate
    (the X component of the launch grid).

    :param params: Scheduler parameters.
    :type params: FmhaStaticTileSchedulerParams
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param grid_shape: Grid shape.
    :type grid_shape: cute.Shape

    :return: New FmhaStaticTileScheduler instance.
    :rtype: FmhaStaticTileScheduler
    """
    initial_linear_idx = blk_coord[0]
    return FmhaStaticTileScheduler(params, initial_linear_idx, blk_coord, grid_shape)
|
||||
|
||||
|
||||
def create_fmha_static_tile_scheduler_params(
    is_persistent: bool,
    problem_shape_mbh: cute.Shape,
) -> FmhaStaticTileSchedulerParams:
    """
    Build the parameter object for the FMHA static tile scheduler.

    Thin convenience wrapper around the FmhaStaticTileSchedulerParams
    constructor.

    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :param problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape

    :return: New FmhaStaticTileSchedulerParams instance.
    :rtype: FmhaStaticTileSchedulerParams
    """
    sched_params = FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh)
    return sched_params
|
||||
|
||||
|
||||
def compute_grid(
    o_shape: cute.Shape,
    cta_tiler: Tuple[int, int, int],
    is_persistent: bool,
) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]:
    """
    Compute grid parameters for FMHA operation.

    Derives the scheduler parameters and launch grid from the output tensor
    shape, the CTA (Cooperative Thread Array) tiler, and the persistent-mode
    flag.

    The output tensor o has shape (s, d, ((h_r, h_k), b)) where:
    - s: sequence length
    - d: head dimension
    - h_r: number of heads for query
    - h_k: number of heads for key
    - b: batch size

    :param o_shape: Output tensor shape for grid computation.
    :type o_shape: cute.Shape
    :param cta_tiler: CTA tiler dimensions (M, N, K).
    :type cta_tiler: Tuple[int, int, int]
    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool

    :return: Tuple of (scheduler_params, grid_shape).
    :rtype: Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]
    """
    # Number of M-tiles along the sequence dimension.
    num_m_blocks = cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0])
    # Extent of the (h_r, h_k) head mode and the batch mode respectively.
    num_h = cute.size(o_shape[2][0])
    num_b = cute.size(o_shape[2][1])

    tile_sched_params = create_fmha_static_tile_scheduler_params(
        is_persistent,
        (num_m_blocks, num_h, num_b),
    )
    grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params)
    return tile_sched_params, grid
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Fused Mask
|
||||
##############################################################################
|
||||
|
||||
|
||||
class MaskEnum(enum.Enum):
    """Enumeration of mask types for FMHA operations.

    - RESIDUAL_MASK: Residual mask for handling variable sequence lengths
    - RESIDUAL_MASK_BWD: Residual mask for the backward pass
    - WINDOW_MASK: Window mask for attention which also includes causal and no mask
    - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation that the end of q is aligned with the end of k
    - WINDOW_MASK_BWD: Window mask for backward pass
    - WINDOW_MASK_BWD_INFERENCE: Same as the window mask for backward pass, but has the limitation that the end of q is aligned with the end of k
    """

    # NOTE: member order is significant — enum.auto() assigns sequential
    # values, so do not reorder.
    RESIDUAL_MASK = enum.auto()
    RESIDUAL_MASK_BWD = enum.auto()
    WINDOW_MASK = enum.auto()
    WINDOW_MASK_INFERENCE = enum.auto()
    WINDOW_MASK_BWD = enum.auto()
    WINDOW_MASK_BWD_INFERENCE = enum.auto()
|
||||
|
||||
|
||||
class FusedMask:
|
||||
"""A fused mask implementation for FMHA operations.
|
||||
|
||||
This class handles different types of attention masks including no mask,
|
||||
residual mask for variable sequence lengths, and causal mask for
|
||||
autoregressive attention patterns.
|
||||
|
||||
The class provides methods to:
|
||||
- Calculate trip counts for different mask types
|
||||
- Apply masks to attention scores
|
||||
- Handle masked and unmasked trip calculations
|
||||
"""
|
||||
|
||||
def get_trip_count(
|
||||
mask_type: MaskEnum,
|
||||
blk_coord: cute.Coord,
|
||||
tile_shape: cute.Shape,
|
||||
seqlen_q: Int32,
|
||||
seqlen_k: Int32,
|
||||
window_size_left: Optional[Int32] = None,
|
||||
window_size_right: Optional[Int32] = None,
|
||||
) -> Int32:
|
||||
"""
|
||||
Calculate the number of trips needed for the current block.
|
||||
|
||||
The trip count depends on the mask type and the block coordinates.
|
||||
For causal masks, it considers the autoregressive constraint.
|
||||
|
||||
:param mask_type: Type of mask to use
|
||||
:type mask_type: utils.MaskEnum
|
||||
:param blk_coord: Block coordinates.
|
||||
:type blk_coord: cute.Coord
|
||||
:param tile_shape: Shape of the tile.
|
||||
:type tile_shape: cute.Shape
|
||||
:param seqlen_q: Query sequence length for attention computation.
|
||||
:type seqlen_q: Int32
|
||||
:param seqlen_k: Key sequence length for attention computation.
|
||||
:type seqlen_k: Int32
|
||||
:param window_size_left: Left-side sliding window size for attention masking.
|
||||
:type window_size_left: Optional[Int32]
|
||||
:param window_size_right: Right-side sliding window size for attention masking.
|
||||
:type window_size_right: Optional[Int32]
|
||||
|
||||
:return: Number of trips needed.
|
||||
:rtype: Int32
|
||||
"""
|
||||
result = 0
|
||||
offset = 0
|
||||
if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
|
||||
offset = seqlen_k - seqlen_q
|
||||
if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
|
||||
offset = seqlen_q - seqlen_k
|
||||
if cutlass.const_expr(mask_type == MaskEnum.RESIDUAL_MASK):
|
||||
result = cute.ceil_div(seqlen_k, tile_shape[1])
|
||||
if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK_BWD):
|
||||
result = cute.ceil_div(seqlen_q, tile_shape[0])
|
||||
if cutlass.const_expr(
|
||||
mask_type == MaskEnum.WINDOW_MASK
|
||||
or mask_type == MaskEnum.WINDOW_MASK_INFERENCE
|
||||
):
|
||||
if cutlass.const_expr(window_size_right is None):
|
||||
result = cute.ceil_div(seqlen_k, tile_shape[1])
|
||||
else:
|
||||
max_idx_q = (blk_coord[0] + 1) * tile_shape[0]
|
||||
idx_k = max_idx_q + offset + window_size_right
|
||||
tmp_blocks_k = cute.ceil_div(idx_k, tile_shape[1])
|
||||
max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
|
||||
result = min(max_blocks_k, tmp_blocks_k)
|
||||
if cutlass.const_expr(
|
||||
mask_type == MaskEnum.WINDOW_MASK_BWD
|
||||
or mask_type == MaskEnum.WINDOW_MASK_BWD_INFERENCE
|
||||
):
|
||||
if cutlass.const_expr(window_size_left is None):
|
||||
result = cute.ceil_div(seqlen_q, tile_shape[0])
|
||||
else:
|
||||
max_idx_k = (blk_coord[1] + 1) * tile_shape[1]
|
||||
idx_k = max_idx_k + offset + window_size_left
|
||||
tmp_blocks_q = cute.ceil_div(idx_k, tile_shape[0])
|
||||
max_blocks_q = cute.ceil_div(seqlen_q, tile_shape[0])
|
||||
result = min(max_blocks_q, tmp_blocks_q)
|
||||
start_block = FusedMask.get_trip_start(
|
||||
mask_type,
|
||||
blk_coord,
|
||||
tile_shape,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
)
|
||||
result = result - start_block
|
||||
return result
|
||||
|
||||
    @cute.jit
    def get_trip_start(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Get the start of the trip for the current block.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Index of the first tile this block must visit (0 when no
            window bound applies).
        :rtype: Int32
        """
        result = 0
        offset = 0
        # Inference variants align the end of q with the end of k, so the mask
        # diagonal is shifted by the sequence-length difference.
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            # Forward pass: first k-tile that can fall inside the left window
            # of this block's smallest q index; clamped to >= 0 via max().
            if cutlass.const_expr(window_size_left is not None):
                min_idx_q = blk_coord[0] * tile_shape[0]
                idx_k = min_idx_q + offset - window_size_left
                tmp_blocks_k = idx_k // tile_shape[1]
                result = max(tmp_blocks_k, result)
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            # Backward pass: symmetric computation over q-tiles.
            if cutlass.const_expr(window_size_right is not None):
                min_idx_k = blk_coord[1] * tile_shape[1]
                idx_q = min_idx_k + offset - window_size_right
                tmp_blocks_q = idx_q // tile_shape[0]
                result = max(tmp_blocks_q, result)
        return result
|
||||
|
||||
    @cute.jit
    def get_leading_mask_id(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Tuple[Int32, Int32]:
        """
        Get the begin and end tile idx for the leading mask.

        An empty leading-mask range is encoded as end < begin.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Tuple of (begin, end) tile idx for the leading mask.
        :rtype: Tuple[Int32, Int32]
        """
        offset = 0
        # Inference variants align the end of q with the end of k.
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k
        # The leading mask starts at the very first tile this block visits.
        leading_mask_begin = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        trip_count = FusedMask.get_trip_count(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )

        leading_mask_end = leading_mask_begin
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_left is not None):
                # Last k-tile still partially cut by the left window of this
                # block's largest q index, clamped to the block's last trip.
                min_idx_q = (
                    (blk_coord[0] + 1) * tile_shape[0] + offset - window_size_left
                )
                leading_mask_end = min(
                    cute.ceil_div(min_idx_q, tile_shape[1]) - 1,
                    trip_count + leading_mask_begin - 1,
                )
            else:
                # No left window: empty range (end < begin).
                leading_mask_end = leading_mask_begin - 1
        elif cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                # Backward pass: symmetric computation over q-tiles.
                min_idx_k = (
                    (blk_coord[1] + 1) * tile_shape[1] + offset - window_size_right
                )
                leading_mask_end = cute.ceil_div(min_idx_k, tile_shape[0]) - 1
            else:
                # No right window: empty range (end < begin).
                leading_mask_end = leading_mask_begin - 1
        return leading_mask_begin, leading_mask_end
|
||||
|
||||
    @cute.jit
    def get_trailing_mask_id(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Tuple[Optional[Int32], Optional[Int32]]:
        """
        Get the begin and end tile idx for the trailing mask.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Tuple of (begin, end) tile idx for the trailing mask.
        :rtype: Tuple[Optional[Int32], Optional[Int32]]
        """
        offset = 0
        # Inference variants align the end of q with the end of k.
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k
        trip_start = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        trip_count = FusedMask.get_trip_count(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )

        trailing_mask_begin, trailing_mask_end = None, None
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                # First k-tile cut by the right window of this block's
                # smallest q index; the trailing region extends to the last trip.
                min_idx_q = blk_coord[0] * tile_shape[0] + offset + window_size_right
                trailing_mask_begin = min(
                    min_idx_q // tile_shape[1], trip_count + trip_start - 1
                )
                trailing_mask_end = trip_count + trip_start - 1
            else:
                # last tile, we always apply mask on it regardless whether it's a residual tile
                trailing_mask_begin = trip_count + trip_start - 1
                trailing_mask_end = trip_count + trip_start - 1
        else:
            if cutlass.const_expr(window_size_left is not None):
                # Backward pass: symmetric computation over q-tiles, with both
                # bounds clamped to this block's last trip.
                min_idx_k = blk_coord[1] * tile_shape[1] + offset + window_size_left + 1
                max_idx_k = (
                    (blk_coord[1] + 1) * tile_shape[1] + offset + window_size_left
                )
                trailing_mask_begin = min(
                    cute.ceil_div(min_idx_k, tile_shape[0]) - 1,
                    trip_count + trip_start - 1,
                )
                trailing_mask_end = min(
                    cute.ceil_div(max_idx_k, tile_shape[0]) - 1,
                    trip_count + trip_start - 1,
                )
            else:
                # last tile, we always apply mask on it regardless whether it's a residual tile
                trailing_mask_begin = trip_count + trip_start - 1
                trailing_mask_end = trip_count + trip_start - 1

        return trailing_mask_begin, trailing_mask_end
|
||||
|
||||
    @cute.jit
    def get_masked_leading_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Calculate the number of masked trips for the leading mask.

        This is used for blocks that need special handling due to masking.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Number of masked trips.
        :rtype: Int32
        """
        result = 0
        # Residual masks have no leading masked region; only window masks do,
        # and only when at least one window bound is set.
        if cutlass.const_expr(
            mask_type is not MaskEnum.RESIDUAL_MASK
            and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
        ):
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                # An empty range is encoded as end < begin, hence the clamp to 0.
                result = max(leading_mask_end - leading_mask_begin + 1, 0)

        return result
|
||||
|
||||
    @cute.jit
    def get_masked_trailing_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
        rem_count: Optional[Int32] = 0,
    ) -> Int32:
        """
        Calculate the number of masked trips for the trailing mask.

        This is used for blocks that need special handling due to masking.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :param rem_count: Remaining count from previous calculations.
        :type rem_count: Int32

        :return: Number of masked trips.
        :rtype: Int32
        """
        # NOTE(review): rem_count is annotated Optional but defaults to 0 and
        # is added unconditionally below — passing None would fail; confirm
        # the annotation is intended.
        result = 0

        if cutlass.const_expr(
            mask_type is not MaskEnum.RESIDUAL_MASK
            and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
        ):
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                trailing_mask_begin, trailing_mask_end = FusedMask.get_trailing_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                if cutlass.const_expr(
                    trailing_mask_begin is not None and trailing_mask_end is not None
                ):
                    # Avoid double-counting tiles already covered by the
                    # leading mask when the two regions overlap.
                    if trailing_mask_begin <= leading_mask_end:
                        result = max(trailing_mask_end - leading_mask_end, 0)
                    else:
                        result = max(trailing_mask_end - trailing_mask_begin + 1, 0)
            else:
                # No window bounds: the only trailing masked tile is a
                # residual (partial) last k-tile, if any.
                if seqlen_k % tile_shape[1] != 0:
                    result = 1
                else:
                    result = 0

        return result + rem_count
|
||||
|
||||
@cute.jit
|
||||
def get_unmasked_trip_count(
|
||||
mask_type: MaskEnum,
|
||||
blk_coord: cute.Coord,
|
||||
tile_shape: cute.Shape,
|
||||
seqlen_q: Int32,
|
||||
seqlen_k: Int32,
|
||||
window_size_left: Optional[Int32] = None,
|
||||
window_size_right: Optional[Int32] = None,
|
||||
) -> Int32:
|
||||
"""
|
||||
Calculate the number of unmasked trips for the current block.
|
||||
|
||||
This represents the number of trips that don't require special
|
||||
masking treatment.
|
||||
|
||||
:param mask_type: Type of mask to use
|
||||
:type mask_type: utils.MaskEnum
|
||||
:param blk_coord: Block coordinates.
|
||||
:type blk_coord: cute.Coord
|
||||
:param tile_shape: Shape of the tile.
|
||||
:type tile_shape: cute.Shape
|
||||
:param seqlen_q: Query sequence length for attention computation.
|
||||
:type seqlen_q: Int32
|
||||
:param seqlen_k: Key sequence length for attention computation.
|
||||
:type seqlen_k: Int32
|
||||
:param window_size_left: Left-side sliding window size for attention masking.
|
||||
:type window_size_left: Optional[Int32]
|
||||
:param window_size_right: Right-side sliding window size for attention masking.
|
||||
:type window_size_right: Optional[Int32]
|
||||
|
||||
:return: Number of unmasked trips.
|
||||
:rtype: Int32
|
||||
"""
|
||||
result = (
|
||||
FusedMask.get_trip_count(
|
||||
mask_type,
|
||||
blk_coord,
|
||||
tile_shape,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
)
|
||||
- FusedMask.get_masked_leading_count(
|
||||
mask_type,
|
||||
blk_coord,
|
||||
tile_shape,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
)
|
||||
- FusedMask.get_masked_trailing_count(
|
||||
mask_type,
|
||||
blk_coord,
|
||||
tile_shape,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
0,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@cute.jit
|
||||
def apply_mask(
|
||||
mask_type: MaskEnum,
|
||||
acc_qk: cute.Tensor,
|
||||
index_qk: cute.Tensor,
|
||||
seqlen_q: Int32,
|
||||
seqlen_k: Int32,
|
||||
window_size_left: Optional[int] = None,
|
||||
window_size_right: Optional[int] = None,
|
||||
index_transform: cutlass.Constexpr = lambda index_q, index_k: (
|
||||
index_q,
|
||||
index_k,
|
||||
),
|
||||
):
|
||||
"""
|
||||
Apply the appropriate mask to the attention scores.
|
||||
|
||||
This method modifies the attention scores (acc_qk) based on the mask type
|
||||
and the positions in the index tensor.
|
||||
|
||||
:param mask_type: Type of mask to use
|
||||
:type mask_type: utils.MaskEnum
|
||||
:param acc_qk: Accumulated QK attention scores tensor.
|
||||
:type acc_qk: cute.Tensor
|
||||
:param index_qk: Index tensor containing position information.
|
||||
:type index_qk: cute.Tensor
|
||||
:param seqlen_k: Key sequence length for attention computation.
|
||||
:type seqlen_k: Int32
|
||||
:param seqlen_q: Query sequence length for attention computation.
|
||||
:type seqlen_q: Optional[int]
|
||||
:param window_size_left: Left-side sliding window size for attention masking.
|
||||
:type window_size_left: Optional[int]
|
||||
:param window_size_right: Right-side sliding window size for attention masking.
|
||||
:type window_size_right: Optional[int]
|
||||
"""
|
||||
|
||||
tidx, tidy, tidx = cute.arch.thread_idx()
|
||||
offset = 0
|
||||
offset = (
|
||||
seqlen_k - seqlen_q
|
||||
if cutlass.const_expr(
|
||||
mask_type is MaskEnum.WINDOW_MASK_INFERENCE
|
||||
or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
|
||||
)
|
||||
else 0
|
||||
)
|
||||
for i in cutlass.range_constexpr(cute.size(acc_qk)):
|
||||
index_q, index_k = index_transform(*index_qk[i])
|
||||
if cutlass.const_expr(
|
||||
window_size_left is not None or window_size_right is not None
|
||||
):
|
||||
if cutlass.const_expr(window_size_left is None):
|
||||
if index_q + offset + window_size_right < index_k:
|
||||
acc_qk[i] = -Float32.inf
|
||||
if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask
|
||||
acc_qk[i] = -Float32.inf
|
||||
elif cutlass.const_expr(window_size_right is None):
|
||||
if index_q + offset - window_size_left > index_k:
|
||||
acc_qk[i] = -Float32.inf
|
||||
if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask
|
||||
acc_qk[i] = -Float32.inf
|
||||
else:
|
||||
max_K_index = min(index_q + offset + window_size_right, seqlen_k)
|
||||
min_K_index = max(0, index_q + offset - window_size_left)
|
||||
if index_k > max_K_index or index_k < min_K_index:
|
||||
acc_qk[i] = -Float32.inf
|
||||
if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask
|
||||
acc_qk[i] = -Float32.inf
|
||||
|
||||
if cutlass.const_expr(
|
||||
mask_type == MaskEnum.RESIDUAL_MASK
|
||||
or mask_type == MaskEnum.RESIDUAL_MASK_BWD
|
||||
):
|
||||
if index_k >= seqlen_k or index_q >= seqlen_q:
|
||||
acc_qk[i] = -Float32.inf
|
||||
457
examples/python/CuTeDSL/utils/sparse_utils.py
Normal file
457
examples/python/CuTeDSL/utils/sparse_utils.py
Normal file
@ -0,0 +1,457 @@
|
||||
import numpy as np
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import torch
|
||||
|
||||
|
||||
@cute.jit
def print_tensor_dlpack(src: cute.Tensor):
    """Print a CuTe tensor twice: its repr, then its element values."""
    # `print` emits the tensor's repr (layout/metadata); `cute.print_tensor`
    # emits the element values from within the JIT context.
    print(src)
    cute.print_tensor(src)
|
||||
|
||||
|
||||
# Sparse emulation
|
||||
class SparseEmulation:
    """Emulate a sparse tensor core GEMM: D += A_compressed x B, with the
    metadata tensor E selecting which B columns each stored A value pairs with.

    A is the compressed operand of shape (M, K // 2); B is (N, K); D is (M, N);
    E holds packed 4-bit metadata words, one per 32 logical K elements.
    """

    def __init__(self, M: int, N: int, K: int, L: int):
        self.M = M
        self.N = N
        self.K = K
        self.L = L
        # Threads per block used by the emulation kernel launch.
        self.num_threads = 128

    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """Launch the sparse-emulation kernel, one thread per output row."""
        grid = (cute.ceil_div(self.M, self.num_threads), 1, 1)
        block = (self.num_threads, 1, 1)
        self.kernel(a, b, d, e).launch(grid=grid, block=block)
        return

    @cute.kernel
    def kernel(self, a: cute.Tensor, b: cute.Tensor, d: cute.Tensor, e: cute.Tensor):
        """CUDA kernel to emulate a sparse tensor core (one row per thread)."""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()

        # BUG FIX: consecutive blocks cover consecutive groups of
        # `num_threads` rows. The original `tidx + bidx * self.M` was only
        # correct when a single block was launched (M <= num_threads) and
        # skipped rows otherwise.
        row_idx = tidx + bidx * self.num_threads
        # Number of packed metadata words per row (each covers 32 elements).
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # Each thread processes one full output row.
            for col in range(self.N):
                for e_idx in range(meta_idx):
                    meta_val = e[(row_idx, e_idx)]
                    for k in range(8):
                        # Each 4-bit nibble describes one 4-element chunk
                        # (2 stored values); idx0/idx1 are the in-chunk
                        # positions of the kept elements.
                        meta_row = (meta_val >> (k * 4)) & 0xF
                        idx0 = meta_row & 0x3
                        idx1 = (meta_row >> 2) & 0x3
                        # Indices of the two stored A values ...
                        km = e_idx * 16 + k * 2
                        km_1 = km + 1
                        # ... and of the matching B elements in logical K space.
                        kn = e_idx * 32 + k * 4 + idx0
                        kn_1 = e_idx * 32 + k * 4 + idx1
                        d[row_idx, col] += a[row_idx, km] * b[col, kn]
                        d[row_idx, col] += a[row_idx, km_1] * b[col, kn_1]
        return
|
||||
|
||||
|
||||
# Compressor
|
||||
# compress a sparse tensor to a dense tensor && generate metadata
|
||||
# Compressor
# compress a sparse tensor to a dense tensor && generate metadata
class Compressor:
    """Compress a 2:4-sparse (M, K) tensor into an (M, K // 2) tensor of kept
    values plus an (M, K // 32) tensor of packed 4-bit metadata nibbles."""

    def __init__(self, M: int, K: int, L: int):
        self.M = M
        self.K = K
        self.L = L
        # Threads per block used by the CUDA compression kernel.
        self.num_threads = 128
        # Maps a metadata nibble to the two in-chunk positions of the kept
        # (non-zero) elements.
        self.pos_map = {
            0x4: [0, 1],
            0x8: [0, 2],
            0xC: [0, 3],
            0x9: [1, 2],
            0xD: [1, 3],
            0xE: [2, 3],
        }

    @cute.jit
    def _init__(self, a: cute.Tensor):
        # NOTE(review): single leading underscore, so this does NOT override
        # __init__ — it looks like a misnamed re-initializer taking the shape
        # from a tensor; confirm whether it is ever called.
        self.__init__(a.shape[0], a.shape[1], a.shape[2])

    def compress(self, a, a_compressed, meta, run_on_cpu: bool):
        """Dispatch compression to the CPU reference path or the CUDA kernel.

        :param a: (M, K) 2:4-sparse input tensor.
        :param a_compressed: (M, K // 2) output tensor for the kept values.
        :param meta: (M, K // 32) output tensor for packed metadata words.
        :param run_on_cpu: run the CPU reference implementation if True.
        :raises ValueError: if the tensor device does not match the path.
        """
        if run_on_cpu:
            if a.device.type != "cpu":
                raise ValueError("a must be on cpu")
            return self.__compress_on_cpu(a, a_compressed, meta)
        else:
            if a.device.type != "cuda":
                raise ValueError("a must be on cuda")
            return self.__compress_on_cuda(a, a_compressed, meta)

    def __compress_on_cpu(self, a, a_compressed, meta):
        """
        compress the tensor on cpu
        # Convert to 4-bit metadata value
        # The metadata value represents which 2 elements are non-zero
        # 0x4: [1,1,0,0] - first two elements are non-zero
        # 0x8: [1,0,1,0] - first and third elements are non-zero
        # 0xC: [1,0,0,1] - first and fourth elements are non-zero
        # 0x9: [0,1,1,0] - second and third elements are non-zero
        # 0xD: [0,1,0,1] - second and fourth elements are non-zero
        # 0xE: [0,0,1,1] - third and fourth elements are non-zero
        # special case (chunks with fewer than 2 non-zeros fold into a
        # pattern that covers the non-zero positions):
        # [0,0,0,0] == [0,0,1,1]
        # [1,0,0,0] == [1,0,0,1]
        # [0,1,0,0] == [0,1,0,1]
        # [0,0,1,0] == [0,0,1,1]
        # [0,0,0,1] == [0,0,1,1]
        """
        M, K = a.shape
        assert a_compressed.shape == (
            M,
            K // 2,
        ), f"Expected a_compressed shape {(M, K // 2)}, got {a_compressed.shape}"
        assert meta.shape == (
            M,
            K // 4 // 8,
        ), f"Expected meta shape {(M, K // 4 // 8)}, got {meta.shape}"
        for m in range(M):
            k_meta = 0
            for k in range(0, K, 4):
                chunk = a[m, k : k + 4]

                non_zero_indices = torch.nonzero(chunk).squeeze()
                meta_val = 0xE
                if torch.equal(non_zero_indices, torch.tensor([0, 1])):
                    meta_val = 0x4
                elif torch.equal(non_zero_indices, torch.tensor([0, 2])):
                    meta_val = 0x8
                elif torch.equal(non_zero_indices, torch.tensor([0, 3])) or torch.equal(
                    non_zero_indices, torch.tensor(0)
                ):
                    meta_val = 0xC
                elif torch.equal(non_zero_indices, torch.tensor([1, 2])):
                    meta_val = 0x9
                elif torch.equal(non_zero_indices, torch.tensor([1, 3])) or torch.equal(
                    non_zero_indices, torch.tensor(1)
                ):
                    meta_val = 0xD
                elif torch.equal(non_zero_indices, torch.tensor([2, 3])) or torch.equal(
                    non_zero_indices, torch.tensor(2)
                ):
                    meta_val = 0xE
                elif torch.equal(non_zero_indices, torch.tensor([])) or torch.equal(
                    non_zero_indices, torch.tensor(3)
                ):
                    meta_val = 0xE
                else:
                    raise ValueError(f"Invalid non-zero pattern: {non_zero_indices}")
                meta_idx = k // 4 // 8
                meta_bit_pos = (k // 4) % 8
                # Zero-initialize each packed metadata word the first time
                # one of its 8 nibbles is written.
                if k_meta == meta_idx:
                    k_meta = meta_idx + 1
                    meta[m, meta_idx] = 0
                meta[m, meta_idx] |= meta_val << (meta_bit_pos * 4)
                # Store the two kept values of this chunk.
                compressed_idx = k // 2
                index = self.pos_map[meta_val]
                a_compressed[m, compressed_idx] = chunk[index[0]]
                a_compressed[m, compressed_idx + 1] = chunk[index[1]]

    def __compress_on_cuda(self, a, a_compressed, meta):
        """
        compress the tensor on cuda
        """
        a_tensor = from_dlpack(a)
        a_compressed_tensor = from_dlpack(a_compressed)
        meta_tensor = from_dlpack(meta)
        self.compress_on_cuda_impl(a_tensor, a_compressed_tensor, meta_tensor)
        return

    @cute.jit
    def compress_on_cuda_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """Compress the input tensor using the metadata (one thread per row)."""
        grid = (cute.ceil_div(self.M, self.num_threads), 1, 1)
        block = (self.num_threads, 1, 1)
        self.compressor_impl(a, a_compressed, meta).launch(grid=grid, block=block)

    @cute.kernel
    def compressor_impl(
        self, a: cute.Tensor, a_compressed: cute.Tensor, meta: cute.Tensor
    ):
        """CUDA kernel to compress the tensor (one row per thread)."""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()

        # BUG FIX: rows advance by the block size across blocks; the original
        # `tidx + bidx * self.M` was only correct when a single block was
        # launched (M <= num_threads). Unused locals m/k were also removed.
        row_idx = tidx + bidx * self.num_threads
        # Number of packed metadata words per row (each covers 32 elements).
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            for i in range(meta_idx):
                meta[row_idx, i] = 0
                # Each nibble j describes one 4-element chunk.
                for j in range(8):
                    val = a[row_idx, i * 32 + j * 4]
                    val_1 = a[row_idx, i * 32 + j * 4 + 1]
                    val_2 = a[row_idx, i * 32 + j * 4 + 2]
                    val_3 = a[row_idx, i * 32 + j * 4 + 3]
                    # Per-element non-zero flags (pre-initialized so the
                    # conditional assignments below are well-defined for
                    # the DSL's dynamic control flow).
                    value_idx = 0
                    value_idx_1 = 0
                    value_idx_2 = 0
                    value_idx_3 = 0
                    pos0 = 0
                    pos1 = 0
                    if val != 0:
                        value_idx = 1
                        pos0 = 0
                    if val_1 != 0:
                        value_idx_1 = 1
                    if val_2 != 0:
                        value_idx_2 = 1
                    if val_3 != 0:
                        value_idx_3 = 1
                    pos = [value_idx, value_idx_1, value_idx_2, value_idx_3]
                    # Map the non-zero pattern to its metadata nibble and to
                    # the two positions whose values are kept (chunks with
                    # fewer than two non-zeros fold into a covering pattern).
                    tmp = 0
                    if pos == [0, 0, 0, 0]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [1, 0, 0, 0]:
                        tmp = 0xC
                        pos0 = 0
                        pos1 = 3
                    elif pos == [0, 1, 0, 0]:
                        tmp = 0xD
                        pos0 = 1
                        pos1 = 3
                    elif pos == [0, 0, 1, 0]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [0, 0, 0, 1]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    elif pos == [1, 1, 0, 0]:
                        tmp = 0x4
                        pos0 = 0
                        pos1 = 1
                    elif pos == [1, 0, 1, 0]:
                        tmp = 0x8
                        pos0 = 0
                        pos1 = 2
                    elif pos == [1, 0, 0, 1]:
                        tmp = 0xC
                        pos0 = 0
                        pos1 = 3
                    elif pos == [0, 1, 1, 0]:
                        tmp = 0x9
                        pos0 = 1
                        pos1 = 2
                    elif pos == [0, 1, 0, 1]:
                        tmp = 0xD
                        pos0 = 1
                        pos1 = 3
                    elif pos == [0, 0, 1, 1]:
                        tmp = 0xE
                        pos0 = 2
                        pos1 = 3
                    meta[row_idx, i] |= tmp << (j * 4)

                    a_compressed[row_idx, i * 16 + j * 2] = a[
                        row_idx, i * 32 + j * 4 + pos0
                    ]
                    a_compressed[row_idx, i * 16 + j * 2 + 1] = a[
                        row_idx, i * 32 + j * 4 + pos1
                    ]

        return
|
||||
|
||||
|
||||
# SparseUtils is used to generate sparse tensor
|
||||
# format torch.Tensor
|
||||
# SparseUtils is used to generate sparse tensor
# format torch.Tensor
class SparseUtils:
    """Generate 2:4-sparse torch tensors, either with a random sparsity
    pattern or from explicitly supplied per-chunk metadata."""

    def __init__(self, M: int, K: int, L: int, dtype):
        self.M = M
        self.K = K
        self.L = L
        # cutlass data type of the generated tensors (see _get_type).
        self.dtype = dtype
        # Threads per block used by the CUDA masking kernel.
        self.num_threads = 128
        # Random per-chunk 4-bit metadata, one nibble per 4 elements.
        self.meta_data = self._generate_meta_data_4_2()
        # When True, generation uses self.meta_data instead of random zeros.
        self._use_specific_meta_data = False

    def _get_type(self):
        """Map the stored cutlass dtype to the matching torch dtype.

        :raises ValueError: for dtypes other than Float16/Float32/Int8.
        """
        if self.dtype == cutlass.Float16:
            return torch.float16
        elif self.dtype == cutlass.Float32:
            return torch.float32
        elif self.dtype == cutlass.Int8:
            return torch.int8
        else:
            raise ValueError(f"Unsupported dtype: {self.dtype}")

    def _generate_meta_data_4_2(self):
        """Draw one random valid 4:2 metadata nibble per 4-element chunk.

        Valid nibbles (positions of the two kept elements):
          0: [1,1,0,0] -> 0x4    1: [1,0,1,0] -> 0x8    2: [1,0,0,1] -> 0xC
          3: [0,1,1,0] -> 0x9    4: [0,1,0,1] -> 0xD    5: [0,0,1,1] -> 0xE
        """
        meta_value = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE]
        # 4:2 sparse, so each chunk is 4 elements and maps to 4 bits.
        K_NumChunk = self.K // 4
        meta_data = np.random.choice(
            meta_value, size=(self.M, K_NumChunk), replace=True
        )
        meta_data = torch.from_numpy(
            np.array(meta_data).astype(np.uint8).reshape(self.M, K_NumChunk)
        )
        return meta_data

    def _pack_meta_data(self):
        """Pack the per-chunk nibbles into uint32 words, 8 nibbles per word,
        lowest nibble first."""
        tmp = []
        K_NumChunk = self.K // 4
        for i in range(self.M):
            for j in range(K_NumChunk // 8):
                v = 0
                for k in range(8):
                    vv = int(self.meta_data[i, j * 8 + k] & 0xF)
                    tt = vv << (k * 4)
                    v = v | tt
                tmp.append(v)
        # debug print
        # print([hex(vt) for vt in tmp])
        result = torch.from_numpy(
            np.array(tmp).astype(np.uint32).reshape(self.M, K_NumChunk // 8)
        )
        return result

    def use_specific_meta_data(self, meta_data: torch.Tensor = None):
        """Make generation deterministic: use `meta_data` if given, otherwise
        the metadata drawn at construction time."""
        if meta_data is not None:
            self.meta_data = meta_data
        self._use_specific_meta_data = True

    def generate_sparse_4_2_tensor_with_tensor(self, a, run_on_cpu):
        """Sparsify the given tensor `a` to a 4:2 pattern (in place on CUDA).

        :param a: (M, K) dense input tensor.
        :param run_on_cpu: run the CPU path (a must be on cpu) or the CUDA
            kernel path (a must be on cuda).
        :return: the sparsified tensor.
        :raises ValueError: if the tensor device does not match the path.
        """
        if run_on_cpu:
            if a.device.type != "cpu":
                raise ValueError("a must be on cpu")
            return self.__generate_sparse_tensor_cpu(a)
        else:
            if a.device.type != "cuda":
                raise ValueError("a must be on cuda")
            a_tensor = from_dlpack(a)
            packed_meta_data = self._pack_meta_data()
            meta_tensor = from_dlpack(packed_meta_data.cuda())
            self.__generate_sparse_tensor_cuda(a_tensor, meta_tensor)
            return a

    def generate_4_2_sparse_tensor(self, run_on_cpu):
        """Create a fresh random (M, K) tensor and sparsify it to 4:2.

        :param run_on_cpu: select the CPU or CUDA generation path.
        :return: the sparsified tensor (on cpu or cuda accordingly).
        """
        dtype = self._get_type()
        a = torch.empty(self.M, self.K).random_(-5, 5).to(dtype)
        if run_on_cpu:
            return self.generate_sparse_4_2_tensor_with_tensor(a, run_on_cpu)
        else:
            return self.generate_sparse_4_2_tensor_with_tensor(a.cuda(), run_on_cpu)

    def __generate_sparse_tensor_cpu(self, a):
        """CPU path: zero out two of every four consecutive elements, either
        at random positions or according to the stored metadata."""
        if not self._use_specific_meta_data:
            for m in range(self.M):
                for k in range(0, self.K, 4):
                    # randomly choose 2 zero positions per 4-element chunk
                    zero_indices = torch.randperm(4)[:2]
                    a[m, k + zero_indices[0]] = 0
                    a[m, k + zero_indices[1]] = 0
            return a
        else:
            # use specific meta data: build a flat 0/1 keep-mask
            tensor_mask = []
            for i in range(self.M):
                for j in range(self.K // 4):
                    meta_val = self.meta_data[i, j]
                    tmp = []
                    if meta_val == 0x4:
                        tmp = [1, 1, 0, 0]
                    elif meta_val == 0x8:
                        tmp = [1, 0, 1, 0]
                    elif meta_val == 0xC:
                        tmp = [1, 0, 0, 1]
                    elif meta_val == 0x9:
                        tmp = [0, 1, 1, 0]
                    elif meta_val == 0xD:
                        tmp = [0, 1, 0, 1]
                    elif meta_val == 0xE:
                        tmp = [0, 0, 1, 1]
                    tensor_mask.extend(tmp)
            a = torch.reshape(a, (-1,))
            mask = torch.tensor(tensor_mask)
            a = a * mask
            a = torch.reshape(a, (self.M, self.K))
            return a

    @cute.jit
    def __generate_sparse_tensor_cuda(self, a: cute.Tensor, meta: cute.Tensor):
        """Generate a sparse tensor from a dense tensor using metadata"""
        assert a.shape[0] == self.M and a.shape[1] == self.K
        assert meta.shape[0] == self.M and meta.shape[1] == self.K // 4 // 8
        grid = (cute.ceil_div(self.M, self.num_threads), 1, 1)
        block = (self.num_threads, 1, 1)
        self.kernel(a, meta).launch(grid=grid, block=block)

    @cute.kernel
    def kernel(self, a: cute.Tensor, meta: cute.Tensor):
        """Apply sparsity mask to input tensor using metadata (one row per thread)."""
        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()

        # BUG FIX: rows advance by the block size across blocks; the original
        # `tidx + bidx * self.M` was only correct when a single block was
        # launched (M <= num_threads).
        row_idx = tidx + bidx * self.num_threads
        # Number of packed metadata words per row (each covers 32 elements).
        meta_idx = self.K // 4 // 8
        if row_idx < self.M:
            # iterate over each chunk (32 elements)
            for i in range(meta_idx):
                meta_val = meta[(row_idx, i)]
                # iterate over each sparse pattern (4 elements)
                for j in range(8):
                    meta_row = (meta_val >> (j * 4)) & 0xF
                    idx0 = meta_row & 0x3
                    idx1 = (meta_row >> 2) & 0x3
                    r_id0 = 0
                    r_id1 = 0
                    # r_id0/r_id1 are the in-chunk positions to zero out
                    # (the complement of the kept positions idx0/idx1).
                    if idx0 >= 2 and idx1 >= 2:
                        r_id0 = 0
                        r_id1 = 1
                    elif idx0 <= 1 and idx1 <= 1:
                        r_id0 = 2
                        r_id1 = 3
                    else:
                        r_id0 = idx0 ^ 0b1
                        r_id1 = idx1 ^ 0b1
                    row_id0 = r_id0 + i * 32 + j * 4
                    row_id1 = r_id1 + i * 32 + j * 4
                    a[row_idx, row_id0] = self.dtype(0.0)
                    a[row_idx, row_id1] = self.dtype(0.0)
        return
|
||||
104
examples/python/CuTeDSL/utils/test_sparse_utils.py
Normal file
104
examples/python/CuTeDSL/utils/test_sparse_utils.py
Normal file
@ -0,0 +1,104 @@
|
||||
import sparse_utils as su
|
||||
import cutlass
|
||||
import torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.L0
def test_sparse_cpu():
    """Round-trip check of the CPU path: generate a 2:4-sparse A on the host,
    compress it on the host, run the CUDA sparse-GEMM emulation, and compare
    against a dense einsum reference."""
    M, N, K, L = 128, 32, 32, 1
    debug = False

    # Generate a 2:4-sparse input tensor on the host.
    a = torch.empty(M, K).random_(-5, 5).to(torch.float16)
    sparse_utils = su.SparseUtils(M, K, L, cutlass.Float16)
    if debug:
        sparse_utils.use_specific_meta_data()
    a_gen_from_cpu = sparse_utils.generate_sparse_4_2_tensor_with_tensor(a, True)
    # print(a_gen_from_cpu)

    # Compress on the host into kept values plus packed metadata.
    a_compressed_cpu = torch.empty(M, K // 2).to(torch.float16)
    meta_data_cpu = torch.empty(M, K // 4 // 8).to(torch.uint32)
    compressor = su.Compressor(M, K, L)
    compressor.compress(a_gen_from_cpu, a_compressed_cpu, meta_data_cpu, True)

    # Run the emulated sparse GEMM on the device.
    b = torch.empty(N, K).random_(-5, 5).to(torch.float16).cuda()
    d = torch.empty(M, N).zero_().to(torch.float16).cuda()
    sparse_emulation = su.SparseEmulation(M, N, K, 1)
    sparse_emulation(
        from_dlpack(a_compressed_cpu.cuda()),
        from_dlpack(b),
        from_dlpack(d),
        from_dlpack(meta_data_cpu.cuda()),
    )

    # Dense reference: D = A @ B^T over the original (uncompressed) A.
    ref = torch.einsum("mk,nk->mn", a_gen_from_cpu.cpu(), b.cpu())
    if debug:
        dumps = {
            "a.txt": a_gen_from_cpu,
            "a_compressed_cpu.txt": a_compressed_cpu,
            "meta_data_cpu.txt": meta_data_cpu,
            "d.txt": d,
            "ref.txt": ref,
        }
        for fname, tensor in dumps.items():
            np.savetxt(fname, tensor.cpu().numpy(), fmt="%f")
    torch.testing.assert_close(d.cpu(), ref)
    print("cpu d == ref")
|
||||
|
||||
|
||||
@pytest.mark.L0
def test_sparse_cuda():
    """Round-trip check of the CUDA path: generate a 2:4-sparse A on the
    device, compress it with the CUDA kernel, run the sparse-GEMM emulation,
    and compare against a dense einsum reference."""
    M = 128
    N = 32
    K = 32
    L = 1
    debug = False
    sparse_utils = su.SparseUtils(M, K, L, cutlass.Float16)
    if debug:
        sparse_utils.use_specific_meta_data()
    # generate sparse tensor on the device
    # FIX: removed an unused `a = torch.empty(...)` allocation here — the
    # generator creates its own input tensor internally.
    a_gen_from_cuda = sparse_utils.generate_4_2_sparse_tensor(False)
    # print(a_gen_from_cuda)
    # generate compressed tensor and meta data
    a_compressed_cuda = torch.empty(M, K // 2).to(torch.float16).cuda()
    meta_data_cuda = torch.empty(M, K // 4 // 8).to(torch.uint32).cuda()
    compressor = su.Compressor(M, K, L)
    compressor.compress(a_gen_from_cuda, a_compressed_cuda, meta_data_cuda, False)
    # test with gemm
    b = torch.empty(N, K).random_(-5, 5).to(torch.float16).cuda()
    d = torch.empty(M, N).zero_().to(torch.float16).cuda()
    b_tensor = from_dlpack(b)
    d_tensor = from_dlpack(d)
    a_compressed_cuda_tensor = from_dlpack(a_compressed_cuda)
    meta_data_cuda_tensor = from_dlpack(meta_data_cuda)
    sparse_emulation = su.SparseEmulation(M, N, K, 1)
    sparse_emulation(
        a_compressed_cuda_tensor, b_tensor, d_tensor, meta_data_cuda_tensor
    )

    # Dense reference: D = A @ B^T over the original (uncompressed) A.
    ref = torch.einsum("mk,nk->mn", a_gen_from_cuda.cpu(), b.cpu())
    if debug:
        a_ori = a_gen_from_cuda.cpu().numpy()
        np.savetxt("a.txt", a_ori, fmt="%f")
        a_compressed_cuda_ori = a_compressed_cuda.cpu().numpy()
        np.savetxt("a_compressed_cuda.txt", a_compressed_cuda_ori, fmt="%f")
        meta_data_cuda_ori = meta_data_cuda.cpu().numpy()
        np.savetxt("meta_data_cuda.txt", meta_data_cuda_ori, fmt="%f")
        d_ori = d.cpu().numpy()
        np.savetxt("d.txt", d_ori, fmt="%f")
        ref_ori = ref.cpu().numpy()
        np.savetxt("ref.txt", ref_ori, fmt="%f")
    torch.testing.assert_close(d.cpu(), ref)
    print("cuda d == ref")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual entry point: initialize the CUDA context once, then run both
    # round-trip tests (CPU-path compression and CUDA-path compression).
    cutlass.cuda.initialize_cuda_context()
    test_sparse_cpu()
    test_sparse_cuda()
|
||||
Reference in New Issue
Block a user