Files
cutlass/examples/python/CuTeDSL/utils/fmha_helpers.py
Junkai-Wu b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00

976 lines
36 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from typing import Tuple, Optional
import cutlass
from cutlass.cute.typing import Boolean
from cutlass.cutlass_dsl import (
Int32,
Float32,
min,
extract_mlir_values,
new_from_mlir_values,
)
from cutlass.utils.hardware_info import HardwareInfo
from cutlass.utils import WorkTileInfo
import cutlass.cute as cute
##############################################################################
# Fmha static tile scheduler
##############################################################################
class FmhaStaticTileSchedulerParams:
    """Configuration bundle for the FMHA (Fused Multi-Head Attention) static tile scheduler.

    Holds the persistence flag and the problem shape, and implements the
    ``__extract_mlir_values__`` / ``__new_from_mlir_values__`` pair so that
    the problem shape can be flattened into MLIR values and rebuilt from them.

    :ivar is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :ivar problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape
    """

    def __init__(
        self,
        is_persistent: bool,
        problem_shape_mbh: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """
        Store the scheduler configuration.

        :param is_persistent: Whether to use persistent kernel mode.
        :type is_persistent: bool
        :param problem_shape_mbh: Problem shape in (M, B, H) format.
        :type problem_shape_mbh: cute.Shape
        :param loc: Optional location object, kept on the instance as ``_loc``.
        :param ip: Optional insertion point, kept on the instance as ``_ip``.
        """
        self._loc = loc
        self._ip = ip
        self.problem_shape_mbh = problem_shape_mbh
        self.is_persistent = is_persistent

    def __extract_mlir_values__(self):
        """
        Flatten the dynamic members into a single list of MLIR values.

        As a side effect, ``self._values_pos`` records how many values each
        member contributed; ``__new_from_mlir_values__`` relies on it to
        slice the flat list back apart.

        :return: Flat list of the extracted values.
        """
        flat_values = []
        self._values_pos = []
        for member in (self.problem_shape_mbh,):
            member_values = extract_mlir_values(member)
            flat_values += member_values
            self._values_pos.append(len(member_values))
        return flat_values

    def __new_from_mlir_values__(self, values):
        """
        Build a new params object from a flat list of MLIR values.

        ``__extract_mlir_values__`` must have run on this instance first,
        because the per-member value counts are read from ``self._values_pos``.
        The new object keeps ``is_persistent`` and ``loc``; ``ip`` is not
        forwarded.

        :param values: Flat value list in the order produced by
            ``__extract_mlir_values__``.
        :return: New FmhaStaticTileSchedulerParams instance.
        :rtype: FmhaStaticTileSchedulerParams
        """
        rebuilt = []
        remaining = values
        for member, count in zip((self.problem_shape_mbh,), self._values_pos):
            rebuilt.append(new_from_mlir_values(member, remaining[:count]))
            remaining = remaining[count:]
        return FmhaStaticTileSchedulerParams(
            self.is_persistent, *rebuilt, loc=self._loc
        )
class FmhaStaticTileScheduler:
    """Static work-tile scheduler for FMHA (Fused Multi-Head Attention) kernels.

    Supports two modes, selected by ``params.is_persistent``:

    - persistent: a linear work index is mapped onto the (M, B, H) problem
      shape and advanced in steps of the grid size;
    - non-persistent: the scheduler hands out the block coordinate it was
      created with exactly once.

    :ivar _params: Scheduler parameters.
    :type _params: FmhaStaticTileSchedulerParams
    :ivar _blk_coord: Block coordinates.
    :type _blk_coord: cute.Coord
    :ivar _grid_shape: Grid shape for the kernel.
    :type _grid_shape: cute.Shape
    :ivar _is_persistent: Whether to use persistent kernel mode.
    :type _is_persistent: bool
    :ivar _current_work_linear_idx: Current linear work index.
    :type _current_work_linear_idx: Int32
    :ivar _problem_shape_mbh: Layout built from the (M, B, H) problem shape.
    :type _problem_shape_mbh: cute.Layout
    :ivar _num_blocks: Number of blocks in the problem.
    :type _num_blocks: Int32
    :ivar _is_first_block: Whether the first block is still pending.
    :type _is_first_block: bool
    :ivar num_persistent_sm: Size of the grid shape, used as the stride
        between consecutive work items in persistent mode.
    :type num_persistent_sm: Int32
    """

    def __init__(
        self,
        params: FmhaStaticTileSchedulerParams,
        current_work_linear_idx: Int32,
        blk_coord: cute.Coord,
        grid_shape: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """
        Set up the scheduler state.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams
        :param current_work_linear_idx: Current linear work index.
        :type current_work_linear_idx: Int32
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param grid_shape: Grid shape for the kernel.
        :type grid_shape: cute.Shape
        """
        self._loc = loc
        self._ip = ip
        self._params = params
        self._is_persistent = params.is_persistent
        self._blk_coord = blk_coord
        self._grid_shape = grid_shape
        self._current_work_linear_idx = current_work_linear_idx
        self._is_first_block = True
        # The problem shape is wrapped in a layout so that a linear work
        # index can later be turned into an (M, B, H) coordinate.
        self._problem_shape_mbh = cute.make_layout(
            params.problem_shape_mbh, loc=loc, ip=ip
        )
        self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip)
        self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: FmhaStaticTileSchedulerParams,
        *,
        loc=None,
        ip=None,
    ) -> cute.Shape:
        """
        Determine the grid shape for the FMHA kernel.

        Non-persistent kernels launch one block per entry of the problem
        shape. Persistent kernels launch a 1-D grid whose size is the smaller
        of the device multiprocessor count and the total number of blocks.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams
        :return: Grid shape; ``(n, 1, 1)`` when persistent, otherwise the
            problem shape itself.
        :rtype: cute.Shape
        """
        if not params.is_persistent:
            return params.problem_shape_mbh
        sm_count = HardwareInfo().get_device_multiprocessor_count()
        total_blocks = cute.size(params.problem_shape_mbh, loc=loc, ip=ip)
        return (min(sm_count, total_blocks), 1, 1)

    @staticmethod
    def check_valid_work_for_seqlen_q(
        q_tiler: int,
        current_idx: Int32,
        seqlen_q: Int32,
    ) -> Boolean:
        """
        Tell whether a query tile starts inside the query sequence.

        :param q_tiler: Query tiler size.
        :type q_tiler: int
        :param current_idx: Current work index.
        :type current_idx: Int32
        :param seqlen_q: Query sequence length.
        :type seqlen_q: Int32
        :return: True when ``current_idx * q_tiler`` is below ``seqlen_q``.
        :rtype: Boolean
        """
        tile_start = current_idx * q_tiler
        return tile_start < seqlen_q

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """
        Describe the work tile the scheduler currently points at.

        In persistent mode the tile is valid while the linear index is below
        the number of blocks, and its coordinate is derived from that index.
        In non-persistent mode the tile is valid only until the first call to
        ``advance_to_next_work`` and the coordinate is the stored block
        coordinate.

        :return: WorkTileInfo containing tile coordinates and validity flag.
        :rtype: WorkTileInfo
        """
        is_valid = (
            self._current_work_linear_idx < self._num_blocks
            if self._is_persistent
            else self._is_first_block
        )
        blk_coord = (0, 0, 0)
        if self._is_persistent:
            blk_coord = self._problem_shape_mbh.get_hier_coord(
                self._current_work_linear_idx, loc=loc, ip=ip
            )
        else:
            blk_coord = self._blk_coord
        # The tile coordinate is laid out as (mid, 0, (bid, hid)).
        m_idx = blk_coord[0]
        batch_head = (blk_coord[1], blk_coord[2])
        return WorkTileInfo((m_idx, 0, batch_head), is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """
        Get the initial work tile information.

        :return: Same result as ``get_current_work``.
        :rtype: WorkTileInfo
        """
        return self.get_current_work(loc=loc, ip=ip)

    def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
        """
        Move on to the next work tile.

        Persistent kernels step the linear index by
        ``advance_count * num_persistent_sm``. In every mode the first-block
        flag is cleared, which is what invalidates further work for
        non-persistent kernels.

        :param advance_count: Number of steps to advance (default: 1).
        :type advance_count: int
        """
        if self._is_persistent:
            stride = advance_count * self.num_persistent_sm
            self._current_work_linear_idx += stride
        self._is_first_block = False

    def __extract_mlir_values__(self):
        """
        Flatten the scheduler state into a list of MLIR values.

        The order is: params, current linear index, block coordinate,
        grid shape.
        """
        flat_values = extract_mlir_values(self._params)
        for member in (
            self._current_work_linear_idx,
            self._blk_coord,
            self._grid_shape,
        ):
            flat_values.extend(extract_mlir_values(member))
        return flat_values

    def __new_from_mlir_values__(self, values):
        """
        Rebuild a scheduler from the flat list produced by
        ``__extract_mlir_values__``.

        Exactly ten values are expected: three for the params, one for the
        linear index, three for the block coordinate and the rest for the
        grid shape.
        """
        assert len(values) == 10
        params_values = values[0:3]
        index_values = [values[3]]
        coord_values = values[4:7]
        grid_values = values[7:]
        new_params = new_from_mlir_values(self._params, params_values)
        new_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, index_values
        )
        new_blk_coord = new_from_mlir_values(self._blk_coord, coord_values)
        new_grid_shape = new_from_mlir_values(self._grid_shape, grid_values)
        return FmhaStaticTileScheduler(
            new_params, new_linear_idx, new_blk_coord, new_grid_shape
        )
def create_fmha_static_tile_scheduler(
    params: FmhaStaticTileSchedulerParams,
    blk_coord: cute.Coord,
    grid_shape: cute.Shape,
) -> FmhaStaticTileScheduler:
    """
    Build an FMHA static tile scheduler for one block.

    The first component of ``blk_coord`` is used as the scheduler's initial
    linear work index.

    :param params: Scheduler parameters.
    :type params: FmhaStaticTileSchedulerParams
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param grid_shape: Grid shape.
    :type grid_shape: cute.Shape
    :return: New FmhaStaticTileScheduler instance.
    :rtype: FmhaStaticTileScheduler
    """
    initial_linear_idx = blk_coord[0]
    return FmhaStaticTileScheduler(
        params, initial_linear_idx, blk_coord, grid_shape
    )
def create_fmha_static_tile_scheduler_params(
    is_persistent: bool,
    problem_shape_mbh: cute.Shape,
) -> FmhaStaticTileSchedulerParams:
    """
    Build the parameter object for the FMHA static tile scheduler.

    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :param problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape
    :return: New FmhaStaticTileSchedulerParams instance.
    :rtype: FmhaStaticTileSchedulerParams
    """
    scheduler_params = FmhaStaticTileSchedulerParams(
        is_persistent, problem_shape_mbh
    )
    return scheduler_params
def compute_grid(
    o_shape: cute.Shape,
    cta_tiler: Tuple[int, int, int],
    is_persistent: bool,
) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]:
    """
    Derive the scheduler parameters and launch grid for an FMHA operation.

    The problem shape handed to the scheduler has three entries:

    - the size of ``o_shape[0]`` divided (rounding up) by ``cta_tiler[0]``,
    - the size of ``o_shape[2][0]``,
    - the size of ``o_shape[2][1]``.

    The output tensor o has shape (s, d, ((h_r, h_k), b)) where:

    - s: sequence length
    - d: head dimension
    - h_r: number of heads for query
    - h_k: number of heads for key
    - b: batch size

    NOTE(review): with that layout ``o_shape[2][0]`` is the head mode and
    ``o_shape[2][1]`` the batch mode, while the scheduler field is named
    "(M, B, H)" -- confirm which ordering is intended.

    :param o_shape: Output tensor shape for grid computation.
    :type o_shape: cute.Shape
    :param cta_tiler: CTA tiler dimensions (M, N, K).
    :type cta_tiler: Tuple[int, int, int]
    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :return: Tuple of (scheduler_params, grid_shape).
    :rtype: Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]
    """
    num_m_blocks = cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0])
    second_extent = cute.size(o_shape[2][0])
    third_extent = cute.size(o_shape[2][1])
    tile_sched_params = create_fmha_static_tile_scheduler_params(
        is_persistent,
        (num_m_blocks, second_extent, third_extent),
    )
    grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params)
    return tile_sched_params, grid
##############################################################################
# Fused Mask
##############################################################################
class MaskEnum(enum.Enum):
    """Mask kinds understood by the FMHA helpers.

    Members, in declaration order:

    - RESIDUAL_MASK: Residual mask for handling variable sequence lengths
    - RESIDUAL_MASK_BWD: Residual mask for the backward pass
    - WINDOW_MASK: Window mask for attention which also includes causal and no mask
    - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation
      that the end of q is aligned with the end of k
    - WINDOW_MASK_BWD: Window mask for backward pass
    - WINDOW_MASK_BWD_INFERENCE: Same as the window mask for backward pass,
      but has the limitation that the end of q is aligned with the end of k
    """

    # Residual (sequence-length boundary) masks.
    RESIDUAL_MASK = enum.auto()
    RESIDUAL_MASK_BWD = enum.auto()
    # Sliding-window masks, forward pass.
    WINDOW_MASK = enum.auto()
    WINDOW_MASK_INFERENCE = enum.auto()
    # Sliding-window masks, backward pass.
    WINDOW_MASK_BWD = enum.auto()
    WINDOW_MASK_BWD_INFERENCE = enum.auto()
class FusedMask:
"""A fused mask implementation for FMHA operations.
This class handles different types of attention masks including no mask,
residual mask for variable sequence lengths, and causal mask for
autoregressive attention patterns.
The class provides methods to:
- Calculate trip counts for different mask types
- Apply masks to attention scores
- Handle masked and unmasked trip calculations
"""
def get_trip_count(
    mask_type: MaskEnum,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_q: Int32,
    seqlen_k: Int32,
    window_size_left: Optional[Int32] = None,
    window_size_right: Optional[Int32] = None,
) -> Int32:
    """
    Calculate the number of trips needed for the current block.

    An exclusive end block is computed per mask type and the start block
    from ``FusedMask.get_trip_start`` is subtracted from it:

    - RESIDUAL_MASK: all K blocks, ``ceil_div(seqlen_k, tile_shape[1])``.
    - RESIDUAL_MASK_BWD: all Q blocks, ``ceil_div(seqlen_q, tile_shape[0])``.
    - WINDOW_MASK / WINDOW_MASK_INFERENCE: all K blocks when
      ``window_size_right`` is None, otherwise capped by the right window
      edge of the last query row of this block.
    - WINDOW_MASK_BWD / WINDOW_MASK_BWD_INFERENCE: all Q blocks when
      ``window_size_left`` is None, otherwise capped by the left window
      edge of the last key column of this block.

    The two INFERENCE variants shift the window by the difference between
    the sequence lengths.

    :param mask_type: Type of mask to use
    :type mask_type: utils.MaskEnum
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param tile_shape: Shape of the tile.
    :type tile_shape: cute.Shape
    :param seqlen_q: Query sequence length for attention computation.
    :type seqlen_q: Int32
    :param seqlen_k: Key sequence length for attention computation.
    :type seqlen_k: Int32
    :param window_size_left: Left-side sliding window size for attention masking.
    :type window_size_left: Optional[Int32]
    :param window_size_right: Right-side sliding window size for attention masking.
    :type window_size_right: Optional[Int32]
    :return: Number of trips needed.
    :rtype: Int32
    """
    end_block = 0
    shift = 0
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
        shift = seqlen_k - seqlen_q
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
        shift = seqlen_q - seqlen_k

    if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK):
        end_block = cute.ceil_div(seqlen_k, tile_shape[1])
    if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK_BWD):
        end_block = cute.ceil_div(seqlen_q, tile_shape[0])

    if cutlass.const_expr(
        mask_type is MaskEnum.WINDOW_MASK
        or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
    ):
        if cutlass.const_expr(window_size_right is None):
            end_block = cute.ceil_div(seqlen_k, tile_shape[1])
        else:
            # One past the last query row covered by this block.
            q_limit = (blk_coord[0] + 1) * tile_shape[0]
            k_limit = q_limit + shift + window_size_right
            window_blocks_k = cute.ceil_div(k_limit, tile_shape[1])
            total_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
            end_block = min(total_blocks_k, window_blocks_k)

    if cutlass.const_expr(
        mask_type is MaskEnum.WINDOW_MASK_BWD
        or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
    ):
        if cutlass.const_expr(window_size_left is None):
            end_block = cute.ceil_div(seqlen_q, tile_shape[0])
        else:
            # One past the last key column covered by this block.
            k_limit = (blk_coord[1] + 1) * tile_shape[1]
            q_limit = k_limit + shift + window_size_left
            window_blocks_q = cute.ceil_div(q_limit, tile_shape[0])
            total_blocks_q = cute.ceil_div(seqlen_q, tile_shape[0])
            end_block = min(total_blocks_q, window_blocks_q)

    first_block = FusedMask.get_trip_start(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
    )
    return end_block - first_block
@cute.jit
def get_trip_start(
    mask_type: MaskEnum,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_q: Int32,
    seqlen_k: Int32,
    window_size_left: Optional[Int32] = None,
    window_size_right: Optional[Int32] = None,
) -> Int32:
    """
    Get the index of the first tile the current block has to visit.

    The start is 0 unless a window bound on the relevant side is given:

    - WINDOW_MASK / WINDOW_MASK_INFERENCE with ``window_size_left``: the K
      block containing the left window edge of the first query row of this
      block, clamped below at 0.
    - WINDOW_MASK_BWD / WINDOW_MASK_BWD_INFERENCE with ``window_size_right``:
      the Q block containing the right window edge of the first key column
      of this block, clamped below at 0.

    :param mask_type: Type of mask to use
    :type mask_type: utils.MaskEnum
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param tile_shape: Shape of the tile.
    :type tile_shape: cute.Shape
    :param seqlen_q: Query sequence length for attention computation.
    :type seqlen_q: Int32
    :param seqlen_k: Key sequence length for attention computation.
    :type seqlen_k: Int32
    :param window_size_left: Left-side sliding window size for attention masking.
    :type window_size_left: Optional[Int32]
    :param window_size_right: Right-side sliding window size for attention masking.
    :type window_size_right: Optional[Int32]
    :return: Index of the first tile to process.
    :rtype: Int32
    """
    first_block = 0
    shift = 0
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
        shift = seqlen_k - seqlen_q
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
        shift = seqlen_q - seqlen_k

    if cutlass.const_expr(
        mask_type is MaskEnum.WINDOW_MASK
        or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
    ):
        if cutlass.const_expr(window_size_left is not None):
            first_q_row = blk_coord[0] * tile_shape[0]
            left_edge_k = first_q_row + shift - window_size_left
            candidate_block_k = left_edge_k // tile_shape[1]
            first_block = max(candidate_block_k, first_block)

    if cutlass.const_expr(
        mask_type is MaskEnum.WINDOW_MASK_BWD
        or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
    ):
        if cutlass.const_expr(window_size_right is not None):
            first_k_col = blk_coord[1] * tile_shape[1]
            right_edge_q = first_k_col + shift - window_size_right
            candidate_block_q = right_edge_q // tile_shape[0]
            first_block = max(candidate_block_q, first_block)

    return first_block
@cute.jit
def get_leading_mask_id(
    mask_type: MaskEnum,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_q: Int32,
    seqlen_k: Int32,
    window_size_left: Optional[Int32] = None,
    window_size_right: Optional[Int32] = None,
) -> Tuple[Int32, Int32]:
    """
    Get the begin and end tile idx for the leading mask.

    The begin index is always the trip start. The end index is inclusive;
    an end of ``begin - 1`` therefore denotes an empty range, which is what
    the window mask types yield when the relevant window bound
    (``window_size_left`` forward, ``window_size_right`` backward) is None.
    For mask types other than the four window types the end equals the begin.

    :param mask_type: Type of mask to use
    :type mask_type: utils.MaskEnum
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param tile_shape: Shape of the tile.
    :type tile_shape: cute.Shape
    :param seqlen_q: Query sequence length for attention computation.
    :type seqlen_q: Int32
    :param seqlen_k: Key sequence length for attention computation.
    :type seqlen_k: Int32
    :param window_size_left: Left-side sliding window size for attention masking.
    :type window_size_left: Optional[Int32]
    :param window_size_right: Right-side sliding window size for attention masking.
    :type window_size_right: Optional[Int32]
    :return: Tuple of (begin, end) tile idx for the leading mask.
    :rtype: Tuple[Int32, Int32]
    """
    shift = 0
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
        shift = seqlen_k - seqlen_q
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
        shift = seqlen_q - seqlen_k

    first_tile = FusedMask.get_trip_start(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
    )
    num_trips = FusedMask.get_trip_count(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
    )

    last_masked_tile = first_tile
    if cutlass.const_expr(
        mask_type is MaskEnum.WINDOW_MASK
        or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
    ):
        if cutlass.const_expr(window_size_left is not None):
            # Left window edge taken at the end of this query block.
            left_edge_k = (
                (blk_coord[0] + 1) * tile_shape[0] + shift - window_size_left
            )
            # Never run past the last tile of the trip.
            last_masked_tile = min(
                cute.ceil_div(left_edge_k, tile_shape[1]) - 1,
                num_trips + first_tile - 1,
            )
        else:
            last_masked_tile = first_tile - 1
    elif cutlass.const_expr(
        mask_type is MaskEnum.WINDOW_MASK_BWD
        or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
    ):
        if cutlass.const_expr(window_size_right is not None):
            # Right window edge taken at the end of this key block.
            right_edge_q = (
                (blk_coord[1] + 1) * tile_shape[1] + shift - window_size_right
            )
            last_masked_tile = cute.ceil_div(right_edge_q, tile_shape[0]) - 1
        else:
            last_masked_tile = first_tile - 1

    return first_tile, last_masked_tile
@cute.jit
def get_trailing_mask_id(
    mask_type: MaskEnum,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_q: Int32,
    seqlen_k: Int32,
    window_size_left: Optional[Int32] = None,
    window_size_right: Optional[Int32] = None,
) -> Tuple[Optional[Int32], Optional[Int32]]:
    """
    Get the begin and end tile idx for the trailing mask.

    Both indices are inclusive and never exceed the last tile of the trip
    (``trip_count + trip_start - 1``). WINDOW_MASK and WINDOW_MASK_INFERENCE
    are driven by ``window_size_right``; every other mask type takes the
    second branch, which is driven by ``window_size_left``. When the driving
    window bound is None, only the last tile of the trip is reported.

    :param mask_type: Type of mask to use
    :type mask_type: utils.MaskEnum
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param tile_shape: Shape of the tile.
    :type tile_shape: cute.Shape
    :param seqlen_q: Query sequence length for attention computation.
    :type seqlen_q: Int32
    :param seqlen_k: Key sequence length for attention computation.
    :type seqlen_k: Int32
    :param window_size_left: Left-side sliding window size for attention masking.
    :type window_size_left: Optional[Int32]
    :param window_size_right: Right-side sliding window size for attention masking.
    :type window_size_right: Optional[Int32]
    :return: Tuple of (begin, end) tile idx for the trailing mask.
    :rtype: Tuple[Int32, Int32]
    """
    shift = 0
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
        shift = seqlen_k - seqlen_q
    if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
        shift = seqlen_q - seqlen_k

    first_tile = FusedMask.get_trip_start(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
    )
    num_trips = FusedMask.get_trip_count(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
    )
    # Index of the final tile visited by this block; used by every branch.
    last_tile = num_trips + first_tile - 1

    trailing_mask_begin, trailing_mask_end = None, None
    if cutlass.const_expr(
        mask_type is MaskEnum.WINDOW_MASK
        or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
    ):
        if cutlass.const_expr(window_size_right is not None):
            # Right window edge of the first query row of this block.
            right_edge_k = (
                blk_coord[0] * tile_shape[0] + shift + window_size_right
            )
            trailing_mask_begin = min(right_edge_k // tile_shape[1], last_tile)
            trailing_mask_end = last_tile
        else:
            # last tile, we always apply mask on it regardless whether it's a residual tile
            trailing_mask_begin = last_tile
            trailing_mask_end = last_tile
    else:
        if cutlass.const_expr(window_size_left is not None):
            # Left window edges of the first and last key column of this block.
            edge_lo = blk_coord[1] * tile_shape[1] + shift + window_size_left + 1
            edge_hi = (
                (blk_coord[1] + 1) * tile_shape[1] + shift + window_size_left
            )
            trailing_mask_begin = min(
                cute.ceil_div(edge_lo, tile_shape[0]) - 1,
                last_tile,
            )
            trailing_mask_end = min(
                cute.ceil_div(edge_hi, tile_shape[0]) - 1,
                last_tile,
            )
        else:
            # last tile, we always apply mask on it regardless whether it's a residual tile
            trailing_mask_begin = last_tile
            trailing_mask_end = last_tile

    return trailing_mask_begin, trailing_mask_end
@cute.jit
def get_masked_leading_count(
    mask_type: MaskEnum,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_q: Int32,
    seqlen_k: Int32,
    window_size_left: Optional[Int32] = None,
    window_size_right: Optional[Int32] = None,
) -> Int32:
    """
    Calculate the number of masked trips for the leading mask.

    The count is 0 for the two residual mask types and whenever both window
    sizes are None. Otherwise it is the length of the inclusive range
    returned by ``FusedMask.get_leading_mask_id``, clamped below at 0.

    :param mask_type: Type of mask to use
    :type mask_type: utils.MaskEnum
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param tile_shape: Shape of the tile.
    :type tile_shape: cute.Shape
    :param seqlen_q: Query sequence length for attention computation.
    :type seqlen_q: Int32
    :param seqlen_k: Key sequence length for attention computation.
    :type seqlen_k: Int32
    :param window_size_left: Left-side sliding window size for attention masking.
    :type window_size_left: Optional[Int32]
    :param window_size_right: Right-side sliding window size for attention masking.
    :type window_size_right: Optional[Int32]
    :return: Number of masked trips.
    :rtype: Int32
    """
    masked_trips = 0
    if cutlass.const_expr(
        mask_type is not MaskEnum.RESIDUAL_MASK
        and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
    ):
        if cutlass.const_expr(
            window_size_left is not None or window_size_right is not None
        ):
            first_masked, last_masked = FusedMask.get_leading_mask_id(
                mask_type,
                blk_coord,
                tile_shape,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
            )
            # Inclusive range; an end below the begin means nothing is masked.
            masked_trips = max(last_masked - first_masked + 1, 0)
    return masked_trips
@cute.jit
def get_masked_trailing_count(
mask_type: MaskEnum,
blk_coord: cute.Coord,
tile_shape: cute.Shape,
seqlen_q: Int32,
seqlen_k: Int32,
window_size_left: Optional[Int32] = None,
window_size_right: Optional[Int32] = None,
rem_count: Optional[Int32] = 0,
) -> Int32:
"""
Calculate the number of masked trips for the trailing mask.
For non-residual mask types with at least one window size given, the
count is derived from the inclusive tile ranges returned by
FusedMask.get_trailing_mask_id and FusedMask.get_leading_mask_id; tiles
already covered by the leading range are not counted a second time.
A separate fallback yields 1 when seqlen_k is not a multiple of
tile_shape[1] and 0 otherwise.
rem_count is added to the computed count before it is returned.
:param mask_type: Type of mask to use
:type mask_type: utils.MaskEnum
:param blk_coord: Block coordinates.
:type blk_coord: cute.Coord
:param tile_shape: Shape of the tile.
:type tile_shape: cute.Shape
:param seqlen_q: Query sequence length for attention computation.
:type seqlen_q: Int32
:param seqlen_k: Key sequence length for attention computation.
:type seqlen_k: Int32
:param window_size_left: Left-side sliding window size for attention masking.
:type window_size_left: Optional[Int32]
:param window_size_right: Right-side sliding window size for attention masking.
:type window_size_right: Optional[Int32]
:param rem_count: Remaining count from previous calculations; defaults to 0.
:type rem_count: Int32
:return: Number of masked trips plus rem_count.
:rtype: Int32
"""
result = 0
# Compile-time dispatch: the window logic applies to non-residual masks only.
if cutlass.const_expr(
mask_type is not MaskEnum.RESIDUAL_MASK
and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
):
# With both window sizes None, nothing is counted here.
if cutlass.const_expr(
window_size_left is not None or window_size_right is not None
):
trailing_mask_begin, trailing_mask_end = FusedMask.get_trailing_mask_id(
mask_type,
blk_coord,
tile_shape,
seqlen_q,
seqlen_k,
window_size_left,
window_size_right,
)
leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id(
mask_type,
blk_coord,
tile_shape,
seqlen_q,
seqlen_k,
window_size_left,
window_size_right,
)
if cutlass.const_expr(
trailing_mask_begin is not None and trailing_mask_end is not None
):
# Trailing range overlaps the leading one: count only the tiles
# after the end of the leading range.
if trailing_mask_begin <= leading_mask_end:
result = max(trailing_mask_end - leading_mask_end, 0)
else:
# Disjoint ranges: count the whole inclusive trailing range.
result = max(trailing_mask_end - trailing_mask_begin + 1, 0)
# NOTE(review): indentation is lost in this copy of the file. This `else`
# presumably pairs with the outermost residual-mask check above, so that
# residual masks count one trailing tile when seqlen_k is not a multiple of
# tile_shape[1] -- confirm against the upstream file.
else:
if seqlen_k % tile_shape[1] != 0:
result = 1
else:
result = 0
return result + rem_count
@cute.jit
def get_unmasked_trip_count(
    mask_type: MaskEnum,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_q: Int32,
    seqlen_k: Int32,
    window_size_left: Optional[Int32] = None,
    window_size_right: Optional[Int32] = None,
) -> Int32:
    """
    Calculate the number of unmasked trips for the current block.

    This is the total trip count minus the leading masked count minus the
    trailing masked count (the latter evaluated with ``rem_count`` = 0).

    :param mask_type: Type of mask to use
    :type mask_type: utils.MaskEnum
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param tile_shape: Shape of the tile.
    :type tile_shape: cute.Shape
    :param seqlen_q: Query sequence length for attention computation.
    :type seqlen_q: Int32
    :param seqlen_k: Key sequence length for attention computation.
    :type seqlen_k: Int32
    :param window_size_left: Left-side sliding window size for attention masking.
    :type window_size_left: Optional[Int32]
    :param window_size_right: Right-side sliding window size for attention masking.
    :type window_size_right: Optional[Int32]
    :return: Number of unmasked trips.
    :rtype: Int32
    """
    total_trips = FusedMask.get_trip_count(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
    )
    leading_masked = FusedMask.get_masked_leading_count(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
    )
    trailing_masked = FusedMask.get_masked_trailing_count(
        mask_type,
        blk_coord,
        tile_shape,
        seqlen_q,
        seqlen_k,
        window_size_left,
        window_size_right,
        0,
    )
    return total_trips - leading_masked - trailing_masked
@cute.jit
def apply_mask(
    mask_type: MaskEnum,
    acc_qk: cute.Tensor,
    index_qk: cute.Tensor,
    seqlen_q: Int32,
    seqlen_k: Int32,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    index_transform: cutlass.Constexpr = lambda index_q, index_k: (
        index_q,
        index_k,
    ),
):
    """
    Apply the appropriate mask to the attention scores.

    Every element of ``acc_qk`` whose (q, k) position is masked out is
    overwritten in place with ``-Float32.inf``. The position of element
    ``i`` is ``index_transform(*index_qk[i])``.

    Masking rules:

    - Only ``window_size_right`` given: masked when
      ``index_q + offset + window_size_right < index_k``.
    - Only ``window_size_left`` given: masked when
      ``index_q + offset - window_size_left > index_k``.
    - Both given: masked when ``index_k`` lies outside
      ``[max(0, index_q + offset - window_size_left),
      min(index_q + offset + window_size_right, seqlen_k)]``.
    - Residual rule (``index_k >= seqlen_k or index_q >= seqlen_q``):
      applied whenever a window size is given, and for the RESIDUAL_MASK /
      RESIDUAL_MASK_BWD mask types.

    ``offset`` is ``seqlen_k - seqlen_q`` for WINDOW_MASK_INFERENCE and
    WINDOW_MASK_BWD_INFERENCE and 0 otherwise.

    :param mask_type: Type of mask to use
    :type mask_type: utils.MaskEnum
    :param acc_qk: Accumulated QK attention scores tensor; modified in place.
    :type acc_qk: cute.Tensor
    :param index_qk: Index tensor containing position information.
    :type index_qk: cute.Tensor
    :param seqlen_q: Query sequence length for attention computation.
    :type seqlen_q: Int32
    :param seqlen_k: Key sequence length for attention computation.
    :type seqlen_k: Int32
    :param window_size_left: Left-side sliding window size for attention masking.
    :type window_size_left: Optional[int]
    :param window_size_right: Right-side sliding window size for attention masking.
    :type window_size_right: Optional[int]
    :param index_transform: Compile-time callable mapping one ``index_qk``
        entry to ``(index_q, index_k)``; the default returns the entry as is.
    :type index_transform: cutlass.Constexpr
    """
    # The original also unpacked cute.arch.thread_idx() into
    # (tidx, tidy, tidx) -- a duplicated target -- and never used the values,
    # so that statement and the dead `offset = 0` store were dropped.
    offset = (
        seqlen_k - seqlen_q
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_INFERENCE
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        )
        else 0
    )
    for i in cutlass.range_constexpr(cute.size(acc_qk)):
        index_q, index_k = index_transform(*index_qk[i])
        if cutlass.const_expr(
            window_size_left is not None or window_size_right is not None
        ):
            if cutlass.const_expr(window_size_left is None):
                # Right bound only.
                if index_q + offset + window_size_right < index_k:
                    acc_qk[i] = -Float32.inf
                if index_k >= seqlen_k or index_q >= seqlen_q:  # residual mask
                    acc_qk[i] = -Float32.inf
            elif cutlass.const_expr(window_size_right is None):
                # Left bound only.
                if index_q + offset - window_size_left > index_k:
                    acc_qk[i] = -Float32.inf
                if index_k >= seqlen_k or index_q >= seqlen_q:  # residual mask
                    acc_qk[i] = -Float32.inf
            else:
                # Both bounds: keep only keys inside the clamped window.
                max_K_index = min(index_q + offset + window_size_right, seqlen_k)
                min_K_index = max(0, index_q + offset - window_size_left)
                if index_k > max_K_index or index_k < min_K_index:
                    acc_qk[i] = -Float32.inf
                if index_k >= seqlen_k or index_q >= seqlen_q:  # residual mask
                    acc_qk[i] = -Float32.inf
        if cutlass.const_expr(
            mask_type == MaskEnum.RESIDUAL_MASK
            or mask_type == MaskEnum.RESIDUAL_MASK_BWD
        ):
            if index_k >= seqlen_k or index_q >= seqlen_q:
                acc_qk[i] = -Float32.inf