# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
"""
|
|
A fused multi-head attention (FMHA) example for the NVIDIA Hopper SM90 architecture using CUTE DSL
|
|
|
|
This example demonstrates an implementation of fused multi-head attention using a TMA + Hopper SM90
|
|
TensorCore warp-specialized kernel. The implementation integrates the Q*K^T matrix multiplication,
|
|
softmax normalization, and softmax(Q*K^T)*V into a single kernel, keeping intermediate results in
shared memory and registers instead of round-tripping them through global memory, thus improving
computational efficiency.
|
|
|
|
The kernel implements key optimizations including:
|
|
- Warp specialization for different computation phases (load, MMA)
|
|
- Two MMA warp groups for compute
|
|
- Pipeline stages between different warps for overlapping computation and memory access
|
|
- Support for different precision data types
|
|
- Optional causal masking for autoregressive models
|
|
- Sliding window attention masking for efficient long sequence processing
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/hopper/fmha.py \
|
|
--qk_acc_dtype Float32 --pv_acc_dtype Float32 \
|
|
--mma_tiler_mn 64,128 \
|
|
--q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \
|
|
--is_persistent
|
|
|
|
The above example runs FMHA with batch size 4, sequence length 1024, 8 attention heads, and head
|
|
dimension 64. The Hopper MMA tile shape is (64, 128), and the kernel uses fp16 for input/output
|
|
with fp32 for accumulation.
|
|
|
|
To collect performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
|
|
ncu python examples/hopper/fmha.py \
|
|
--qk_acc_dtype Float32 --pv_acc_dtype Float32 \
|
|
--mma_tiler_mn 64,128 \
|
|
--q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \
|
|
--is_persistent --warmup_iterations 10 \
|
|
--iterations 10 --skip_ref_check
|
|
|
|
Constraints for this example:
|
|
* Supported head dimensions: 32, 64, 128, 256
|
|
** For head dimension 256, `mma_tiler_mn` should be (64, 256) and non-persistent mode must be used (do not pass `--is_persistent` on the command line)
|
|
* The number of heads in Q must be divisible by the number of heads in K
|
|
* For causal masking, use --is_causal (note: specify without =True/False)
|
|
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
|
|
* For sliding window, use --window_size x,y where x is left window size and y is right window size
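
For reference, the computation fused by this kernel is equivalent to the following
illustrative PyTorch sketch (hypothetical helper named here ``fmha_reference``; tensor
layout (B, S, H, D) as above; GQA head broadcast and sliding-window masking omitted):

.. code-block:: python

    import math
    import torch

    def fmha_reference(q, k, v, scale=None, causal=False):
        # q, k, v: (B, S, H, D); returns (B, S_q, H, D)
        scale = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
        s = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
        if causal:
            mask = torch.ones(s.shape[-2:], dtype=torch.bool, device=s.device).tril()
            s = s.masked_fill(~mask, float("-inf"))
        p = torch.softmax(s, dim=-1)
        return torch.einsum("bhqk,bkhd->bqhd", p, v.float())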
|
|
"""
|
|
|
|
import argparse
|
|
import math
|
|
import os
|
|
import sys
|
|
import time
|
|
from typing import Type, Tuple, Optional
|
|
|
|
import torch
|
|
|
|
|
|
import cuda.bindings.driver as cuda
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
|
|
import cutlass.cute.testing as testing
|
|
import cutlass.cute.nvgpu.warpgroup as warpgroup
|
|
import cutlass.utils as utils
|
|
import cutlass.pipeline as pipeline
|
|
import cutlass.torch as cutlass_torch
|
|
from cutlass._mlir.dialects import math as _math
|
|
|
|
import cutlass.utils.hopper_helpers as sm90_utils
|
|
from cutlass.cute.runtime import from_dlpack
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
sys.path.insert(0, os.path.join(current_dir, ".."))
|
|
from utils import fmha_helpers as fmha_utils
|
|
|
|
|
|
class HopperFusedMultiHeadAttentionForward:
|
|
def __init__(
|
|
self,
|
|
qk_acc_dtype,
|
|
pv_acc_dtype,
|
|
mma_tiler,
|
|
is_persistent,
|
|
mask_type: fmha_utils.MaskEnum,
|
|
):
|
|
"""Initializes the configuration for a Hopper Fused Multi-Head Attention (FMHA) kernel.
|
|
|
|
This configuration includes several key aspects:
|
|
|
|
1. Data Type Settings:
|
|
- qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator
|
|
- pv_acc_dtype: Data type for P*V matrix multiplication accumulator
|
|
|
|
2. MMA Instruction Settings:
|
|
- mma_tiler: The (M, N, K) shape of the MMA instruction unit
|
|
- qk_mma_tiler: MMA shape for Q*K^T computation
|
|
- pv_mma_tiler: MMA shape for P*V computation
|
|
|
|
3. Kernel Execution Mode:
|
|
- is_persistent: Boolean indicating whether to use persistent kernel mode
|
|
- mask_type: Specifies the type of mask to use (no mask, residual mask, or causal mask)
|
|
- window_size_left/right: Sliding window parameters for attention masking
|
|
|
|
:param qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator
|
|
:type qk_acc_dtype: Type[cutlass.Numeric]
|
|
:param pv_acc_dtype: Data type for P*V matrix multiplication accumulator
|
|
:type pv_acc_dtype: Type[cutlass.Numeric]
|
|
:param mma_tiler: The (M, N, K) shape of the MMA instruction
|
|
:type mma_tiler: Tuple[int, int, int]
|
|
:param is_persistent: Whether to use persistent kernel mode
|
|
:type is_persistent: bool
|
|
:param mask_type: Type of mask to use
|
|
:type mask_type: fmha_utils.MaskEnum
|
|
"""
|
|
|
|
self.num_mma_warp_groups = 2
|
|
self.qk_acc_dtype = qk_acc_dtype
|
|
self.pv_acc_dtype = pv_acc_dtype
|
|
self.cta_tiler = self.cta_tile_shape_mnk = (
|
|
mma_tiler[0] * self.num_mma_warp_groups,
|
|
mma_tiler[1],
|
|
mma_tiler[2],
|
|
)
|
|
|
|
self.qk_mma_tiler = (
|
|
mma_tiler[0],
|
|
mma_tiler[1],
|
|
mma_tiler[2],
|
|
)
|
|
|
|
self.pv_mma_tiler = (
|
|
self.qk_mma_tiler[0],
|
|
self.qk_mma_tiler[2],
|
|
self.qk_mma_tiler[1],
|
|
)
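# Illustrative example (not executed): with mma_tiler = (64, 128, 64) and two MMA warp
# groups, cta_tiler = (128, 128, 64), qk_mma_tiler = (64, 128, 64) for S = Q*K^T, and
# pv_mma_tiler = (64, 64, 128) for O = P*V (the N and K extents of the QK tile swap roles).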
|
|
|
|
self.cluster_shape_mn = (1, 1)
|
|
self.atom_layout_mnk = (1, 1, 1)
|
|
self.is_persistent = is_persistent
|
|
self.mask_type = mask_type
|
|
self.threads_per_warp = 32
|
|
self.num_threads_per_warp_group = 128
|
|
self.num_warps_per_warp_group = (
|
|
self.num_threads_per_warp_group / self.threads_per_warp
|
|
)
|
|
|
|
# WarpGroupRole
|
|
self.load_warp_group_id = 0
|
|
self.compute_epilogue_0_warp_group_id = 1
|
|
self.compute_epilogue_1_warp_group_id = 2
|
|
# ProducerWarpRole
|
|
self.producer_warp_loadkv_id = 1
|
|
|
|
self.num_regs_load = 40 - 2 * 8
|
|
num_load_warp_groups = 1
|
|
self.num_threads_per_warp_group = 128
|
|
max_threads_per_block = (
|
|
self.num_mma_warp_groups + num_load_warp_groups
|
|
) * self.num_threads_per_warp_group
|
|
self.threads_per_cta = max_threads_per_block
|
|
self.num_regs_mma = 240
|
|
self.buffer_align_bytes = 1024
|
|
|
|
def _setup_attributes(self):
|
|
self.q_stage = 2
|
|
self.kv_stage = 5
|
|
self.epi_stage = 2
|
|
|
|
@cute.jit
|
|
def __call__(
|
|
self,
|
|
q: cute.Tensor,
|
|
k: cute.Tensor,
|
|
v: cute.Tensor,
|
|
o: cute.Tensor,
|
|
lse: cute.Tensor,
|
|
scale_softmax_log2: cutlass.Float32,
|
|
scale_softmax: cutlass.Float32,
|
|
scale_output: cutlass.Float32,
|
|
window_size_left: Optional[cutlass.Int32],
|
|
window_size_right: Optional[cutlass.Int32],
|
|
stream: cuda.CUstream,
|
|
):
|
|
# setup static attributes before smem/grid/tma computation
|
|
self.q_dtype = q.element_type
|
|
self.k_dtype = k.element_type
|
|
self.v_dtype = v.element_type
|
|
self.o_dtype = o.element_type
|
|
|
|
# (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast
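# K/V carry h_k heads while Q carries h = h_r * h_k heads; giving the h_r mode a 0 stride
# lets every group of h_r query heads read the same K/V head (grouped-query attention
# broadcast) without materializing copies.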
|
|
k = cute.make_tensor(
|
|
k.iterator,
|
|
cute.make_layout(
|
|
(k.shape[0], k.shape[1], ((q.shape[2], k.shape[3]), k.shape[4])),
|
|
stride=(
|
|
k.stride[0],
|
|
k.stride[1],
|
|
((0, k.stride[3]), k.stride[4]),
|
|
),
|
|
),
|
|
)
|
|
|
|
# (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast
|
|
v = cute.make_tensor(
|
|
v.iterator,
|
|
cute.make_layout(
|
|
(v.shape[1], v.shape[0], ((q.shape[2], v.shape[3]), v.shape[4])),
|
|
stride=(
|
|
v.stride[1],
|
|
v.stride[0],
|
|
((0, v.stride[3]), v.stride[4]),
|
|
),
|
|
),
|
|
)
|
|
|
|
# (s, d, ((h_r, h_k), b))
|
|
q = cute.group_modes(cute.group_modes(q, begin=2, end=4), begin=2, end=4)
|
|
o = cute.group_modes(cute.group_modes(o, begin=2, end=4), begin=2, end=4)
|
|
|
|
# (s, ((h_r, h_k), b))
|
|
lse = cute.make_tensor(
|
|
lse.iterator,
|
|
cute.make_layout(
|
|
(
|
|
lse.shape[0],
|
|
self.pv_mma_tiler[1],
|
|
((lse.shape[2], lse.shape[3]), lse.shape[4]),
|
|
),
|
|
stride=(
|
|
lse.stride[0],
|
|
0,
|
|
((lse.stride[2], lse.stride[3]), lse.stride[4]),
|
|
),
|
|
),
|
|
)
|
|
|
|
if cutlass.const_expr(self.q_dtype != self.k_dtype):
|
|
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}")
|
|
if cutlass.const_expr(self.q_dtype != self.v_dtype):
|
|
raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
|
|
|
|
if cutlass.const_expr(q.leading_dim != 1): # k-major
|
|
raise RuntimeError("The layout of q is not supported")
|
|
|
|
if cutlass.const_expr(k.leading_dim != 1): # k-major
|
|
raise RuntimeError("The layout of k is not supported")
|
|
|
|
self._setup_attributes()
|
|
|
|
tile_shape_mnk = self.cta_tiler
|
|
self.epi_tile = sm90_utils.compute_tile_shape_or_override(
|
|
tile_shape_mnk, self.o_dtype
|
|
)
|
|
|
|
self.q_layout = utils.LayoutEnum.from_tensor(q)
|
|
self.k_layout = utils.LayoutEnum.from_tensor(k)
|
|
self.v_layout = utils.LayoutEnum.from_tensor(v)
|
|
self.o_layout = utils.LayoutEnum.from_tensor(o)
|
|
|
|
self.q_major_mode = self.q_layout.sm90_mma_major_mode()
|
|
self.k_major_mode = self.k_layout.sm90_mma_major_mode()
|
|
self.v_major_mode = self.v_layout.sm90_mma_major_mode()
|
|
|
|
p_major_mode = cute.nvgpu.warpgroup.OperandMajorMode.K
|
|
qk_tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
|
self.q_dtype,
|
|
self.k_dtype,
|
|
self.q_major_mode,
|
|
self.k_major_mode,
|
|
self.qk_acc_dtype,
|
|
self.atom_layout_mnk,
|
|
self.qk_mma_tiler[:2],
|
|
)
|
|
|
|
pv_tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
|
self.v_dtype,
|
|
self.v_dtype,
|
|
p_major_mode,
|
|
self.v_major_mode,
|
|
self.pv_acc_dtype,
|
|
self.atom_layout_mnk,
|
|
self.pv_mma_tiler[:2],
|
|
warpgroup.OperandSource.RMEM,
|
|
)
|
|
|
|
self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
|
|
self.cluster_layout_vmnk = cute.tiled_divide(
|
|
cute.make_layout(self.cluster_shape_mnk),
|
|
(qk_tiled_mma.thr_id.shape,),
|
|
)
|
|
|
|
q_smem_layout_staged = sm90_utils.make_smem_layout_a(
|
|
self.q_layout,
|
|
self.qk_mma_tiler,
|
|
self.q_dtype,
|
|
self.q_stage,
|
|
)
|
|
|
|
k_smem_layout_staged = sm90_utils.make_smem_layout_b(
|
|
self.k_layout,
|
|
self.qk_mma_tiler,
|
|
self.k_dtype,
|
|
self.kv_stage,
|
|
)
|
|
|
|
v_smem_layout_staged = sm90_utils.make_smem_layout_b(
|
|
self.v_layout,
|
|
self.pv_mma_tiler,
|
|
self.v_dtype,
|
|
self.kv_stage,
|
|
)
|
|
|
|
o_smem_layout_staged = sm90_utils.make_smem_layout_epi(
|
|
self.o_dtype,
|
|
self.o_layout,
|
|
self.epi_tile,
|
|
self.epi_stage,
|
|
cute.append(
|
|
cute.append(self.epi_tile, self.epi_stage), self.num_mma_warp_groups
|
|
),
|
|
smem_order=(1, 0, 2, 3) if self.o_layout.is_m_major_c() else (0, 1, 2, 3),
|
|
)
|
|
|
|
# TMA load for Q
|
|
q_smem_layout = cute.slice_(q_smem_layout_staged, (None, None, 0))
|
|
tma_atom_q, tma_tensor_q = self._make_tma_atoms_and_tensors(
|
|
q,
|
|
q_smem_layout_staged,
|
|
(self.qk_mma_tiler[0], self.qk_mma_tiler[2]),
|
|
self.cluster_shape_mnk[1],
|
|
)
|
|
|
|
# TMA load for K
|
|
k_smem_layout = cute.slice_(k_smem_layout_staged, (None, None, 0))
|
|
tma_atom_k, tma_tensor_k = self._make_tma_atoms_and_tensors(
|
|
k,
|
|
k_smem_layout_staged,
|
|
(self.qk_mma_tiler[1], self.qk_mma_tiler[2]),
|
|
self.cluster_shape_mnk[0],
|
|
)
|
|
|
|
# TMA load for V
|
|
pv_tile_shape_mnk = (
|
|
self.qk_mma_tiler[0],
|
|
self.qk_mma_tiler[2],
|
|
self.qk_mma_tiler[1],
|
|
)
|
|
tma_atom_v, tma_tensor_v = self._make_tma_atoms_and_tensors(
|
|
v,
|
|
v_smem_layout_staged,
|
|
(pv_tile_shape_mnk[1], pv_tile_shape_mnk[2]),
|
|
self.cluster_shape_mnk[0],
|
|
)
|
|
|
|
o_cta_v_layout = cute.composition(
|
|
cute.make_identity_layout(o.shape), self.epi_tile
|
|
)
|
|
o_smem_layout = cute.slice_(o_smem_layout_staged, (None, None, 0, 0))
|
|
|
|
tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp()
|
|
tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
|
tma_store_op,
|
|
o,
|
|
o_smem_layout,
|
|
self.epi_tile,
|
|
)
|
|
|
|
q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout)
|
|
k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout)
|
|
self.tma_copy_q_bytes = q_copy_size
|
|
self.tma_copy_kv_bytes = k_copy_size
|
|
|
|
self.tile_sched_params, grid = fmha_utils.compute_grid(
|
|
o.shape,
|
|
self.cta_tiler,
|
|
self.is_persistent,
|
|
)
|
|
|
|
@cute.struct
|
|
class SharedStorage:
|
|
# 2 for full/empty
|
|
load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.q_stage * 2]
|
|
load_kv_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2]
|
|
MathWarpGroupOrderBarrier: cute.struct.MemRange[
|
|
cutlass.Int64, self.num_mma_warp_groups
|
|
]
|
|
|
|
sO: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.o_dtype,
|
|
(
|
|
cute.cosize(o_smem_layout_staged)
|
|
if cutlass.const_expr(self.is_persistent)
|
|
else 0
|
|
),
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
sQ: cute.struct.Align[
|
|
cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)],
|
|
self.buffer_align_bytes,
|
|
]
|
|
sK: cute.struct.Align[
|
|
cute.struct.MemRange[self.k_dtype, cute.cosize(k_smem_layout_staged)],
|
|
self.buffer_align_bytes,
|
|
]
|
|
|
|
self.shared_storage = SharedStorage
|
|
|
|
# Launch the kernel synchronously
|
|
self.kernel(
|
|
qk_tiled_mma,
|
|
pv_tiled_mma,
|
|
tma_atom_q,
|
|
tma_tensor_q,
|
|
tma_atom_k,
|
|
tma_tensor_k,
|
|
tma_atom_v,
|
|
tma_tensor_v,
|
|
tma_atom_o,
|
|
tma_tensor_o,
|
|
lse,
|
|
scale_softmax_log2,
|
|
scale_softmax,
|
|
scale_output,
|
|
window_size_left,
|
|
window_size_right,
|
|
q_smem_layout_staged,
|
|
k_smem_layout_staged,
|
|
v_smem_layout_staged,
|
|
o_smem_layout_staged,
|
|
self.tile_sched_params,
|
|
).launch(
|
|
grid=grid,
|
|
block=[self.threads_per_cta, 1, 1],
|
|
cluster=self.cluster_shape_mnk,
|
|
smem=self.shared_storage.size_in_bytes(),
|
|
stream=stream,
|
|
min_blocks_per_mp=1,
|
|
)
|
|
|
|
# GPU device kernel
|
|
@cute.kernel
|
|
def kernel(
|
|
self,
|
|
qk_tiled_mma: cute.TiledMma,
|
|
pv_tiled_mma: cute.TiledMma,
|
|
tma_atom_q: cute.CopyAtom,
|
|
mQ_qdl: cute.Tensor,
|
|
tma_atom_k: cute.CopyAtom,
|
|
mK_kdl: cute.Tensor,
|
|
tma_atom_v: cute.CopyAtom,
|
|
mV_dkl: cute.Tensor,
|
|
tma_atom_o: cute.CopyAtom,
|
|
mO_qdl: cute.Tensor,
|
|
mLse_qdl: cute.Tensor,
|
|
scale_softmax_log2: cutlass.Float32,
|
|
scale_softmax: cutlass.Float32,
|
|
scale_output: cutlass.Float32,
|
|
window_size_left: Optional[cutlass.Int32],
|
|
window_size_right: Optional[cutlass.Int32],
|
|
q_smem_layout_staged: cute.ComposedLayout,
|
|
k_smem_layout_staged: cute.ComposedLayout,
|
|
v_smem_layout_staged: cute.ComposedLayout,
|
|
o_smem_layout_staged: cute.ComposedLayout,
|
|
tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams,
|
|
):
|
|
"""The device kernel implementation of the Fused Multi-Head Attention for Hopper architecture.
|
|
|
|
This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation:
|
|
1. Load warp group: Loads Q, K, V data from global memory to shared memory using TMA
|
|
2. Compute warp groups: Perform the matrix multiplications (Q*K^T and P*V) using Hopper TensorCores,
then compute the numerically stable softmax normalization on the attention scores,
and handle the final output transformation and storage.
|
|
|
|
The kernel implements a complex pipeline with overlapping computation and memory operations,
|
|
using tensor memory access (TMA) for efficient data loading, warp specialization for different
|
|
computation phases, and optional attention masking for causal or residual attention patterns.
|
|
|
|
Key optimizations include:
|
|
- Warp group specialization for load, compute/epilogue phases
|
|
- Pipeline stages between different warps for overlapping computation and memory access
|
|
- Efficient shared memory layouts optimized for Hopper architecture
|
|
- Support for different precision data types and accumulation types
|
|
- Optional causal masking for autoregressive models
|
|
- Sliding window attention masking for efficient long sequence processing
|
|
|
|
:param qk_tiled_mma: Tiled MMA for Q*K^T matrix multiplication
|
|
:type qk_tiled_mma: cute.TiledMma
|
|
:param pv_tiled_mma: Tiled MMA for P*V matrix multiplication
|
|
:type pv_tiled_mma: cute.TiledMma
|
|
:param tma_atom_q: TMA copy atom for query tensor loading
|
|
:type tma_atom_q: cute.CopyAtom
|
|
:param mQ_qdl: Partitioned query tensor for TMA loading
|
|
:type mQ_qdl: cute.Tensor
|
|
:param tma_atom_k: TMA copy atom for key tensor loading
|
|
:type tma_atom_k: cute.CopyAtom
|
|
:param mK_kdl: Partitioned key tensor for TMA loading
|
|
:type mK_kdl: cute.Tensor
|
|
:param tma_atom_v: TMA copy atom for value tensor loading
|
|
:type tma_atom_v: cute.CopyAtom
|
|
:param mV_dkl: Partitioned value tensor for TMA loading
|
|
:type mV_dkl: cute.Tensor
|
|
:param tma_atom_o: TMA copy atom for output tensor storage
|
|
:type tma_atom_o: cute.CopyAtom
|
|
:param mO_qdl: Partitioned output tensor for TMA storage
|
|
:type mO_qdl: cute.Tensor
|
|
:param mLse_qdl: Output tensor for the log-sum-exp (LSE) values
|
|
:type mLse_qdl: cute.Tensor
|
|
:param scale_softmax_log2: The log2 scale factor for softmax computation
|
|
:type scale_softmax_log2: cutlass.Float32
|
|
:param scale_softmax: The scale factor for softmax, used when computing the log-sum-exp (LSE)
|
|
:type scale_softmax: cutlass.Float32
|
|
:param scale_output: The scale factor for the final output
|
|
:type scale_output: cutlass.Float32
|
|
:param window_size_left: Left-side sliding window size for attention masking
|
|
:type window_size_left: Optional[cutlass.Int32]
|
|
:param window_size_right: Right-side sliding window size for attention masking
|
|
:type window_size_right: Optional[cutlass.Int32]
|
|
:param q_smem_layout_staged: Shared memory layout for query tensor with staging
|
|
:type q_smem_layout_staged: cute.ComposedLayout
|
|
:param k_smem_layout_staged: Shared memory layout for key tensor with staging
|
|
:type k_smem_layout_staged: cute.ComposedLayout
|
|
:param v_smem_layout_staged: Shared memory layout for value tensor with staging
|
|
:type v_smem_layout_staged: cute.ComposedLayout
|
|
:param o_smem_layout_staged: Shared memory layout for output tensor with staging
|
|
:type o_smem_layout_staged: cute.ComposedLayout
|
|
:param tile_sched_params: Scheduling parameters for work distribution across blocks
|
|
:type tile_sched_params: fmha_utils.FmhaStaticTileSchedulerParams
|
|
"""
|
|
|
|
tidx, _, _ = cute.arch.thread_idx()
|
|
|
|
# Alloc
|
|
smem = utils.SmemAllocator()
|
|
storage = smem.allocate(self.shared_storage)
|
|
|
|
load_q_producer, load_q_consumer = self.make_and_init_load_q_pipeline(
|
|
storage.load_q_mbar_ptr.data_ptr()
|
|
)
|
|
load_kv_producer, load_kv_consumer = self.make_and_init_load_kv_pipeline(
|
|
storage.load_kv_mbar_ptr.data_ptr()
|
|
)
|
|
tma_store_pipeline = self.make_and_init_tma_store_pipeline()
|
|
|
|
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
warp_group_idx = cute.arch.make_warp_uniform(
|
|
tidx // self.num_threads_per_warp_group
|
|
)
|
|
|
|
math_wg_order_barrier = self.make_and_init_order_barrier(
|
|
storage.MathWarpGroupOrderBarrier.data_ptr(),
|
|
warp_group_idx - 1,
|
|
)
|
|
|
|
# Generate smem tensor Q/K/V/O
|
|
# (MMA, MMA_Q, MMA_D, PIPE)
|
|
sQ = storage.sQ.get_tensor(
|
|
q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner
|
|
)
|
|
# (MMA, MMA_K, MMA_D, PIPE)
|
|
sK = storage.sK.get_tensor(
|
|
k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner
|
|
)
|
|
# (MMA, MMA_K, MMA_D, PIPE)
|
|
sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner)
|
|
sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer)
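# V aliases the K shared-memory buffer: K and V tiles alternate stages of the same kv
# pipeline, so the same storage is reinterpreted with the V swizzle and layout.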
|
|
|
|
if cutlass.const_expr(self.is_persistent):
|
|
sO = storage.sO.get_tensor(
|
|
o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner
|
|
)
|
|
else:
|
|
sO = cute.make_tensor(
|
|
cute.recast_ptr(sQ.iterator, o_smem_layout_staged.inner, self.o_dtype),
|
|
o_smem_layout_staged.outer,
|
|
)
|
|
|
|
seqlen_q = mQ_qdl.shape[0]
|
|
gQ_qdl = cute.flat_divide(mQ_qdl, cute.select(self.qk_mma_tiler, mode=[0, 2]))
|
|
qk_thr_mma = qk_tiled_mma.get_slice(tidx)
|
|
tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl)
|
|
|
|
tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition(
|
|
tma_atom_q,
|
|
0, # no multicast
|
|
cute.make_layout(1),
|
|
cute.group_modes(sQ, 0, 2),
|
|
cute.group_modes(tSgQ_qdl, 0, 3),
|
|
)
|
|
|
|
seqlen_k = mK_kdl.shape[0]
|
|
gK_kdl = cute.flat_divide(mK_kdl, cute.select(self.qk_mma_tiler, mode=[1, 2]))
|
|
tSgK_kdl = qk_thr_mma.partition_B(gK_kdl)
|
|
tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition(
|
|
tma_atom_k,
|
|
0, # no multicast
|
|
cute.make_layout(1),
|
|
cute.group_modes(sK, 0, 2),
|
|
cute.group_modes(tSgK_kdl, 0, 3),
|
|
)
|
|
|
|
gV_dkl = cute.flat_divide(mV_dkl, cute.select(self.pv_mma_tiler, mode=[1, 2]))
|
|
pv_thr_mma = pv_tiled_mma.get_slice(tidx)
|
|
tSgV_dkl = pv_thr_mma.partition_B(gV_dkl)
|
|
tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition(
|
|
tma_atom_v,
|
|
0, # no multicast
|
|
cute.make_layout(1),
|
|
cute.group_modes(sV, 0, 2),
|
|
cute.group_modes(tSgV_dkl, 0, 3),
|
|
)
|
|
|
|
producer_warp_role = warp_idx % 4 # self.num_warps_per_warp_group
|
|
|
|
# We need this to guarantee that the pipeline init is visible to all
# producer and consumer blocks in the cluster, and that smem init has finished
|
|
if cute.size(self.cluster_shape_mnk) > 1:
|
|
cute.arch.cluster_arrive_relaxed()
|
|
cute.arch.cluster_wait()
|
|
else:
|
|
cute.arch.sync_threads()
|
|
|
|
if warp_idx == 0:
|
|
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q)
|
|
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k)
|
|
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v)
|
|
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_o)
|
|
|
|
if warp_group_idx == self.load_warp_group_id:
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
|
|
|
|
tile_sched = fmha_utils.create_fmha_static_tile_scheduler(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
while work_tile.is_valid_tile:
|
|
curr_block_coord = work_tile.tile_idx
|
|
|
|
q0_index = 0
|
|
k_index = fmha_utils.FusedMask.get_trip_start(
|
|
self.mask_type,
|
|
curr_block_coord,
|
|
self.cta_tiler,
|
|
seqlen_q,
|
|
seqlen_k,
|
|
window_size_left,
|
|
)
|
|
fusion_tile_count = fmha_utils.FusedMask.get_trip_count(
|
|
self.mask_type,
|
|
curr_block_coord,
|
|
self.cta_tiler,
|
|
seqlen_q,
|
|
seqlen_k,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
|
|
q_tile_count = self.num_mma_warp_groups
|
|
k_tile_count = 2 * fusion_tile_count
|
|
|
|
curr_block_coord_m = curr_block_coord[0]
|
|
_tQgQ = tQgQ_qdl[(None, None, 0, curr_block_coord[2])]
|
|
tQgQ = cute.domain_offset(
|
|
(0, curr_block_coord_m * self.num_mma_warp_groups), _tQgQ
|
|
)
|
|
|
|
if producer_warp_role == self.producer_warp_loadkv_id:
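# TMA issue order per work tile: Q0, K0, Q1, V0, then alternating K/V tiles for the
# remaining trips, so the MMA warp groups can start the first Q*K^T without waiting for V.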
|
|
# LoadQ
|
|
if q_tile_count > 0:
|
|
q_handle = load_q_producer.acquire_and_advance()
|
|
cute.copy(
|
|
tma_atom_q,
|
|
tQgQ[(None, q0_index)],
|
|
tQsQ[(None, q_handle.index)],
|
|
tma_bar_ptr=q_handle.barrier,
|
|
)
|
|
q0_index += 1
|
|
|
|
q_tile_count -= 1
|
|
|
|
tKgK = tKgK_kdl[(None, None, 0, curr_block_coord[2])]
|
|
tVgV = tVgV_dkl[(None, 0, None, curr_block_coord[2])]
|
|
|
|
# Load K
|
|
if k_tile_count > 0:
|
|
k_handle = load_kv_producer.acquire_and_advance()
|
|
cute.copy(
|
|
tma_atom_k,
|
|
tKgK[(None, k_index)],
|
|
tKsK[(None, k_handle.index)],
|
|
tma_bar_ptr=k_handle.barrier,
|
|
)
|
|
|
|
k_tile_count -= 1
|
|
|
|
# Q1
|
|
if q_tile_count > 0:
|
|
q_handle = load_q_producer.acquire_and_advance()
|
|
cute.copy(
|
|
tma_atom_q,
|
|
tQgQ[(None, q0_index)],
|
|
tQsQ[(None, q_handle.index)],
|
|
tma_bar_ptr=q_handle.barrier,
|
|
)
|
|
q0_index += 1
|
|
q_tile_count -= 1
|
|
|
|
# LoadV
|
|
if k_tile_count > 0:
|
|
k_handle = load_kv_producer.acquire_and_advance()
|
|
cute.copy(
|
|
tma_atom_v,
|
|
tVgV[(None, k_index)],
|
|
tVsV[(None, k_handle.index)],
|
|
tma_bar_ptr=k_handle.barrier,
|
|
)
|
|
|
|
k_index += 1
|
|
k_tile_count -= 1
|
|
|
|
while k_tile_count > 0:
|
|
# Load KV
|
|
k_handle = load_kv_producer.acquire_and_advance()
|
|
cute.copy(
|
|
tma_atom_k,
|
|
tKgK[(None, k_index)],
|
|
tKsK[(None, k_handle.index)],
|
|
tma_bar_ptr=k_handle.barrier,
|
|
)
|
|
|
|
k_tile_count -= 1
|
|
|
|
v_handle = load_kv_producer.acquire_and_advance()
|
|
cute.copy(
|
|
tma_atom_v,
|
|
tVgV[(None, k_index)],
|
|
tVsV[(None, v_handle.index)],
|
|
tma_bar_ptr=v_handle.barrier,
|
|
)
|
|
|
|
k_index += 1
|
|
k_tile_count -= 1
|
|
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
|
|
# Mainloop
|
|
if (
|
|
warp_group_idx == self.compute_epilogue_0_warp_group_id
|
|
or warp_group_idx == self.compute_epilogue_1_warp_group_id
|
|
):
|
|
cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
|
|
|
|
tile_sched = fmha_utils.create_fmha_static_tile_scheduler(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
kOuterLoads = 1
|
|
|
|
cP = cute.make_identity_tensor((mQ_qdl.shape[0], seqlen_k))
|
|
gPcP = cute.local_tile(cP, self.qk_mma_tiler[:2], (None, None))
|
|
|
|
while work_tile.is_valid_tile:
|
|
for i in cutlass.range((warp_group_idx - 1) * kOuterLoads, unroll=1):
|
|
load_q_consumer.advance()
|
|
|
|
curr_block_coord = work_tile.tile_idx
|
|
|
|
# _wg_coord_1 is work_tile.tile_idx[1], which is always 0.
|
|
_wg_coord_0 = self.num_mma_warp_groups * curr_block_coord[0] + (
|
|
warp_group_idx - 1
|
|
)
|
|
_wg_coord_1 = curr_block_coord[1]
|
|
|
|
wg_coord = (_wg_coord_0, _wg_coord_1, *curr_block_coord[2:])
|
|
|
|
# Mainloop setup QK
|
|
tSsQ = qk_thr_mma.partition_A(sQ) # (MMA,MMA_M,MMA_K,PIPE)
|
|
tSsK = qk_thr_mma.partition_B(sK) # (MMA,MMA_N,MMA_K,PIPE)
|
|
tSrQ = qk_thr_mma.make_fragment_A(tSsQ) # (MMA,MMA_M,MMA_K,PIPE)
|
|
tSrK = qk_thr_mma.make_fragment_B(tSsK) # (MMA,MMA_N,MMA_K,PIPE)
|
|
|
|
# Prepare: MMA PV
|
|
thr_mma_pv = pv_tiled_mma.get_slice(tidx)
|
|
|
|
# Mainloop setup PV
|
|
tOsV = thr_mma_pv.partition_B(sV) # (MMA,MMA_N,MMA_K,PIPE)
|
|
tOrV = thr_mma_pv.make_fragment_B(tOsV) # (MMA,MMA_M,MMA_N,PIPE)
|
|
|
|
q_handle = load_q_consumer.wait()
|
|
|
|
# mapping into QK accumulator
|
|
ptPcP = qk_thr_mma.partition_C(gPcP)
|
|
|
|
# Allocate PV acc
|
|
pv_acc_shape = pv_thr_mma.partition_shape_C(
|
|
(self.pv_mma_tiler[0], self.pv_mma_tiler[1])
|
|
)
|
|
acc_pv = pv_thr_mma.make_fragment_C(pv_acc_shape)
|
|
|
|
qk_acc_shape = qk_thr_mma.partition_shape_C(
|
|
(self.qk_mma_tiler[0], self.qk_mma_tiler[1])
|
|
)
|
|
|
|
s_max_layout = cute.make_layout(
|
|
cute.size(layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
|
|
)
|
|
s_max = cute.make_rmem_tensor_like(s_max_layout, self.qk_acc_dtype)
|
|
a_sum = cute.make_rmem_tensor_like(s_max, cutlass.Float32)
|
|
|
|
kv_offset = fmha_utils.FusedMask.get_trip_start(
|
|
self.mask_type,
|
|
curr_block_coord,
|
|
self.cta_tiler,
|
|
seqlen_q,
|
|
seqlen_k,
|
|
window_size_left,
|
|
)
|
|
|
|
masked_leading_count = fmha_utils.FusedMask.get_masked_leading_count(
|
|
self.mask_type,
|
|
curr_block_coord,
|
|
self.cta_tiler,
|
|
seqlen_q,
|
|
seqlen_k,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
unmasked_trip_count = fmha_utils.FusedMask.get_unmasked_trip_count(
|
|
self.mask_type,
|
|
curr_block_coord,
|
|
self.cta_tiler,
|
|
seqlen_q,
|
|
seqlen_k,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
|
|
# mapping into QK accumulator
|
|
tPcP = cute.slice_(ptPcP, (None, None, None, wg_coord[0], kv_offset))
|
|
kv_offset += 1
|
|
|
|
qk_acc_shape = qk_thr_mma.partition_shape_C(
|
|
(self.qk_mma_tiler[0], self.qk_mma_tiler[1])
|
|
)
|
|
|
|
# Allocate QK acc
|
|
acc_qk = qk_thr_mma.make_fragment_C(qk_acc_shape)
|
|
k_handle = load_kv_consumer.wait_and_advance()
|
|
math_wg_order_barrier.wait()
|
|
|
|
# MMA QK
|
|
cute.nvgpu.warpgroup.fence()
|
|
|
|
gemm_zero_acc(
|
|
qk_tiled_mma,
|
|
tSrQ[(None, None, None, q_handle.index)],
|
|
tSrK[(None, None, None, k_handle.index)],
|
|
acc_qk,
|
|
)
|
|
cute.nvgpu.warpgroup.commit_group()
|
|
|
|
math_wg_order_barrier.arrive()
|
|
|
|
# Wait for the pipeline MMAs to drain
|
|
cute.nvgpu.warpgroup.wait_group(0)
|
|
|
|
s_max, a_sum = softmax_step(
|
|
True,
|
|
self.mask_type,
|
|
acc_qk,
|
|
qk_tiled_mma,
|
|
tPcP,
|
|
s_max,
|
|
a_sum,
|
|
acc_qk,
|
|
qk_tiled_mma,
|
|
scale_softmax_log2,
|
|
seqlen_k,
|
|
seqlen_q,
|
|
window_size_left,
|
|
window_size_right,
|
|
True,
|
|
)
|
|
|
|
acc_qk_fixed = make_acc_into_op(
|
|
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
|
|
)
|
|
|
|
v_handle = load_kv_consumer.wait_and_advance()
|
|
|
|
# MMA PV
|
|
cute.nvgpu.warpgroup.fence()
|
|
|
|
gemm_zero_acc(
|
|
pv_tiled_mma,
|
|
acc_qk_fixed,
|
|
tOrV[(None, None, None, v_handle.index)],
|
|
acc_pv,
|
|
)
|
|
cute.nvgpu.warpgroup.commit_group()
|
|
cute.nvgpu.warpgroup.wait_group(0)
|
|
|
|
k_handle.release()
|
|
v_handle.release()
|
|
|
|
if masked_leading_count >= 1:
|
|
masked_leading_count -= 1
|
|
load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum = (
|
|
self.compute(
|
|
True,
|
|
masked_leading_count,
|
|
qk_thr_mma,
|
|
acc_pv,
|
|
qk_tiled_mma,
|
|
pv_tiled_mma,
|
|
load_kv_consumer,
|
|
q_handle,
|
|
tSrQ,
|
|
tSrK,
|
|
s_max,
|
|
a_sum,
|
|
tOrV,
|
|
ptPcP,
|
|
wg_coord,
|
|
kv_offset,
|
|
scale_softmax_log2,
|
|
seqlen_k,
|
|
seqlen_q,
|
|
qk_acc_shape,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
)
|
|
else:
|
|
unmasked_trip_count -= 1
|
|
|
|
load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum = self.compute(
|
|
False,
|
|
unmasked_trip_count,
|
|
qk_thr_mma,
|
|
acc_pv,
|
|
qk_tiled_mma,
|
|
pv_tiled_mma,
|
|
load_kv_consumer,
|
|
q_handle,
|
|
tSrQ,
|
|
tSrK,
|
|
s_max,
|
|
a_sum,
|
|
tOrV,
|
|
ptPcP,
|
|
wg_coord,
|
|
kv_offset,
|
|
scale_softmax_log2,
|
|
seqlen_k,
|
|
seqlen_q,
|
|
qk_acc_shape,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
|
|
k_tile_count = fmha_utils.FusedMask.get_masked_trailing_count(
|
|
self.mask_type,
|
|
curr_block_coord,
|
|
self.cta_tiler,
|
|
seqlen_q,
|
|
seqlen_k,
|
|
window_size_left,
|
|
window_size_right,
|
|
k_tile_count,
|
|
)
|
|
|
|
# Use fusion in softmax
|
|
load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum = self.compute(
|
|
True,
|
|
k_tile_count,
|
|
qk_thr_mma,
|
|
acc_pv,
|
|
qk_tiled_mma,
|
|
pv_tiled_mma,
|
|
load_kv_consumer,
|
|
q_handle,
|
|
tSrQ,
|
|
tSrK,
|
|
s_max,
|
|
a_sum,
|
|
tOrV,
|
|
ptPcP,
|
|
wg_coord,
|
|
kv_offset,
|
|
scale_softmax_log2,
|
|
seqlen_k,
|
|
seqlen_q,
|
|
qk_acc_shape,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
|
|
if cutlass.const_expr(self.is_persistent):
|
|
q_handle.release()
|
|
|
|
# Wait for the pipeline MMAs to drain
|
|
cute.nvgpu.warpgroup.wait_group(0)
|
|
|
|
# acc_pv updated
|
|
lse = tail(
|
|
s_max, a_sum, acc_pv, pv_tiled_mma, scale_softmax, scale_output
|
|
)
|
|
|
|
if warp_group_idx == self.compute_epilogue_0_warp_group_id:
|
|
for i in cutlass.range_constexpr(
|
|
kOuterLoads * (self.num_mma_warp_groups - 0)
|
|
):
|
|
load_q_consumer.advance()
|
|
|
|
if cutlass.const_expr(self.num_mma_warp_groups >= 2):
|
|
if warp_group_idx == self.compute_epilogue_1_warp_group_id:
|
|
for i in cutlass.range_constexpr(
|
|
kOuterLoads * (self.num_mma_warp_groups - 1)
|
|
):
|
|
load_q_consumer.advance()
|
|
|
|
math_wg_order_barrier.wait()
|
|
|
|
# store log-sum-exp (LSE)
|
|
thr_mma = pv_tiled_mma.get_slice(tidx)
|
|
|
|
gLSE_full = cute.local_tile(
|
|
mLse_qdl, self.pv_mma_tiler[:2], (None, None, None)
|
|
)
|
|
gLSE = cute.slice_(
|
|
gLSE_full, (None, None, wg_coord[0], wg_coord[1], wg_coord[2])
|
|
)
|
|
|
|
tOgLSE = thr_mma.partition_C(gLSE)
|
|
cO = cute.make_identity_tensor(
|
|
(self.pv_mma_tiler[0], self.pv_mma_tiler[1])
|
|
)
|
|
tOcO = thr_mma.partition_C(cO)
|
|
|
|
if tOcO[0][1] == 0:
|
|
tOgLSE_mn = cute.make_tensor(
|
|
tOgLSE.iterator, layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
|
|
)
|
|
tOcO_mn = cute.make_tensor(
|
|
tOcO.iterator, layout_acc_mn(pv_tiled_mma, tOcO.layout)
|
|
)
|
|
for i in cutlass.range_constexpr(cute.size(tOgLSE_mn, mode=[0])):
|
|
if (
|
|
tOcO_mn[i][0] + wg_coord[0] * self.pv_mma_tiler[0]
|
|
< seqlen_q
|
|
):
|
|
tOgLSE_mn[(i, 0)] = lse[i]
|
|
|
|
# Epilogue
|
|
cO = cute.make_identity_tensor((self.cta_tiler[0], self.cta_tiler[2]))
|
|
copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
|
|
self.o_layout,
|
|
elem_ty_d=self.o_dtype,
|
|
elem_ty_acc=self.pv_acc_dtype,
|
|
)
|
|
|
|
copy_atom_O = cute.make_copy_atom(
|
|
cute.nvgpu.warp.StMatrix8x8x16bOp(
|
|
self.o_layout.is_m_major_c(),
|
|
4,
|
|
),
|
|
self.o_dtype,
|
|
)
|
|
|
|
tiled_copy_O_Atom = cute.make_tiled_copy_C_atom(
|
|
copy_atom_O, pv_tiled_mma
|
|
)
|
|
|
|
tiled_copy_r2s = cute.make_tiled_copy_S(
|
|
copy_atom_r2s,
|
|
tiled_copy_O_Atom,
|
|
)
|
|
|
|
thr_copy_r2s = tiled_copy_r2s.get_slice(
|
|
tidx % self.num_threads_per_warp_group
|
|
)
|
|
tRS_sD = thr_copy_r2s.partition_D(sO)
|
|
tRS_rAcc = tiled_copy_r2s.retile(acc_pv)
|
|
|
|
# Allocate D registers.
|
|
rD_shape = cute.shape(thr_copy_r2s.partition_S(sO))
|
|
tRS_rD_layout = cute.make_layout(rD_shape[:3])
|
|
|
|
tRS_rD = cute.make_rmem_tensor_like(tRS_rD_layout, self.pv_acc_dtype)
|
|
size_tRS_rD = cute.size(tRS_rD)
|
|
|
|
gD = cute.local_tile(
|
|
mO_qdl,
|
|
self.pv_mma_tiler[:2],
|
|
(wg_coord[0], 0, wg_coord[2]),
|
|
)
|
|
|
|
sepi_for_tma_partition = cute.group_modes(sO, 0, 2)
|
|
tcgc_for_tma_partition = cute.zipped_divide(gD, self.epi_tile)
|
|
|
|
bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition(
|
|
tma_atom_o,
|
|
0,
|
|
cute.make_layout(1),
|
|
sepi_for_tma_partition,
|
|
tcgc_for_tma_partition,
|
|
)
|
|
|
|
epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
|
|
|
|
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
|
# Copy from accumulators to D registers
|
|
for epi_v in cutlass.range_constexpr(size_tRS_rD):
|
|
tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
|
|
|
|
# Type conversion
|
|
tRS_rD_out = cute.make_rmem_tensor_like(tRS_rD_layout, self.o_dtype)
|
|
acc_vec = tRS_rD.load()
|
|
tRS_rD_out.store(acc_vec.to(self.o_dtype))
|
|
|
|
# Copy from D registers to shared memory
|
|
epi_buffer = epi_idx % self.epi_stage
|
|
cute.copy(
|
|
tiled_copy_r2s,
|
|
tRS_rD_out,
|
|
tRS_sD[(None, None, None, epi_buffer, warp_group_idx - 1)],
|
|
)
|
|
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
pipeline.arrive_and_wait(
|
|
barrier_id=warp_group_idx,
|
|
num_threads=self.num_threads_per_warp_group,
|
|
)
|
|
|
|
# only one warp in each warp group copies shared memory to global memory
|
|
if warp_idx == 4 or warp_idx == 8:
|
|
cute.copy(
|
|
tma_atom_o,
|
|
bSG_sD[(None, epi_buffer, warp_group_idx - 1)],
|
|
bSG_gD[(None, epi_idx)],
|
|
)
|
|
|
|
tma_store_pipeline.producer_commit()
|
|
tma_store_pipeline.producer_acquire()
|
|
|
|
pipeline.arrive_and_wait(
|
|
barrier_id=warp_group_idx,
|
|
num_threads=self.num_threads_per_warp_group,
|
|
)
|
|
|
|
math_wg_order_barrier.arrive()
|
|
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
return
|
|
|
|
@cute.jit
|
|
def compute(
|
|
self,
|
|
fusion: bool,
|
|
k_tile_count: cutlass.Int32,
|
|
qk_thr_mma: cute.ThrMma,
|
|
acc_pv: cute.ThrMma,
|
|
qk_tiled_mma: cute.TiledMma,
|
|
pv_tiled_mma: cute.TiledMma,
|
|
load_kv_consumer: pipeline.PipelineConsumer,
|
|
q_handle: pipeline.PipelineConsumer.ImmutableResourceHandle,
|
|
tSrQ: cute.Tensor,
|
|
tSrK: cute.Tensor,
|
|
s_max: cute.Tensor,
|
|
a_sum: cute.Tensor,
|
|
tOrV: cute.Tensor,
|
|
ptPcP: cute.Tensor,
|
|
wg_coord: tuple,
|
|
kv_offset: cutlass.Int32,
|
|
scale_softmax_log2: cutlass.Float32,
|
|
seqlen_k: cutlass.Int32,
|
|
seqlen_q: cutlass.Int32,
|
|
qk_acc_shape: cute.Shape,
|
|
window_size_left: Optional[cutlass.Int32],
|
|
window_size_right: Optional[cutlass.Int32],
|
|
) -> Tuple[
|
|
pipeline.PipelineConsumer,
|
|
cutlass.Int32,
|
|
cutlass.Int32,
|
|
cute.Tensor,
|
|
cute.Tensor,
|
|
]:
|
|
while k_tile_count > 0:
|
|
k_tile_count -= 1
|
|
|
|
tPcP = cute.slice_(ptPcP, (None, None, None, wg_coord[0], kv_offset))
|
|
kv_offset += 1
|
|
|
|
# Allocate QK acc
|
|
acc_qk = qk_thr_mma.make_fragment_C(qk_acc_shape)
|
|
|
|
k_handle = load_kv_consumer.wait_and_advance()
|
|
|
|
# MMA QK
|
|
cute.nvgpu.warpgroup.fence()
|
|
|
|
gemm_zero_acc(
|
|
qk_tiled_mma,
|
|
tSrQ[(None, None, None, q_handle.index)],
|
|
tSrK[(None, None, None, k_handle.index)],
|
|
acc_qk,
|
|
)
|
|
|
|
cute.nvgpu.warpgroup.commit_group()
|
|
|
|
tok = load_kv_consumer.try_wait()
|
|
|
|
# Wait for the pipeline MMAs to drain
|
|
cute.nvgpu.warpgroup.wait_group(0)
|
|
|
|
s_max, a_sum = softmax_step(
|
|
fusion,
|
|
self.mask_type,
|
|
acc_qk,
|
|
qk_tiled_mma,
|
|
tPcP,
|
|
s_max,
|
|
a_sum,
|
|
acc_pv,
|
|
pv_tiled_mma,
|
|
scale_softmax_log2,
|
|
seqlen_k,
|
|
seqlen_q,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
|
|
acc_qk_fixed = make_acc_into_op(
|
|
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
|
|
)
|
|
|
|
v_handle = load_kv_consumer.wait_and_advance(tok)
|
|
|
|
# MMA PV
|
|
cute.nvgpu.warpgroup.fence()
|
|
|
|
pv_tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
|
|
cute.gemm(
|
|
pv_tiled_mma,
|
|
acc_pv,
|
|
acc_qk_fixed,
|
|
tOrV[(None, None, None, v_handle.index)],
|
|
acc_pv,
|
|
)
|
|
|
|
cute.nvgpu.warpgroup.commit_group()
|
|
cute.nvgpu.warpgroup.wait_group(0)
|
|
|
|
k_handle.release()
|
|
v_handle.release()
|
|
|
|
return load_kv_consumer, k_tile_count, kv_offset, s_max, a_sum
|
|
|
|
@cute.jit
|
|
def softmax_step(
|
|
fusion: bool,
|
|
mask_type: fmha_utils.MaskEnum,
|
|
acc_qk: cute.ThrMma,
|
|
tiled_mma_qk: cute.TiledMma,
|
|
count_qk: cute.Tensor,
|
|
s_max: cute.Tensor,
|
|
a_sum: cute.Tensor,
|
|
acc_pv: cute.ThrMma,
|
|
tiled_mma_pv: cute.TiledMma,
|
|
scale_softmax_log2: cutlass.Float32,
|
|
seqlen_k: cutlass.Int32,
|
|
seqlen_q: cutlass.Int32,
|
|
window_size_left: Optional[cutlass.Int32],
|
|
window_size_right: Optional[cutlass.Int32],
|
|
is_first_iter: bool = False,
|
|
) -> Tuple[cute.Tensor, cute.Tensor]:
|
|
if cutlass.const_expr(fusion):
|
|
fmha_utils.FusedMask.apply_mask(
|
|
mask_type,
|
|
acc_qk,
|
|
count_qk,
|
|
seqlen_q,
|
|
seqlen_k,
|
|
window_size_left,
|
|
window_size_right,
|
|
)
|
|
|
|
acc_qk_mn = cute.make_tensor(
|
|
acc_qk.iterator, layout_acc_mn(tiled_mma_qk, acc_qk.layout)
|
|
)
|
|
|
|
reduction_target_qk = reduction_target_n(tiled_mma_qk)
|
|
red_rank = cute.rank(reduction_target_qk)
|
|
|
|
s_max_prev = None
|
|
acc_pv_mn = None
|
|
if cutlass.const_expr(is_first_iter):
|
|
# Linear reduction is faster for the first iteration
|
|
for i in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[0])):
|
|
s_max[i] = acc_qk_mn[i, 0]
|
|
|
|
for j in cutlass.range_constexpr(1, cute.size(acc_qk_mn, mode=[1])):
|
|
for i in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[0])):
|
|
s_max[i] = cute.arch.fmax(s_max[i], acc_qk_mn[i, j])
|
|
else:
|
|
acc_pv_mn = cute.make_tensor(
|
|
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
|
)
|
|
s_max_prev = cute.make_rmem_tensor_like(s_max, s_max._dtype)
|
|
|
|
for i in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[0])):
|
|
if cutlass.const_expr(not is_first_iter):
|
|
s_max_prev[i] = s_max[i]
|
|
|
|
# Linear reduction is faster here, as well
|
|
for j in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[1])):
|
|
s_max[i] = cutlass.max(s_max[i], acc_qk_mn[i, j])
|
|
|
|
# reduce max
|
|
for r in cutlass.range_constexpr(red_rank):
|
|
s_max[i] = cute.arch.warp_reduction_max(
|
|
s_max[i], threads_in_group=reduction_target_qk.shape[r]
|
|
)
|
|
|
|
local_max = s_max[i]
|
|
if s_max[i] == -cutlass.Float32.inf:
|
|
local_max = 0.0
|
|
scale_max = scale_softmax_log2 * local_max
|
|
|
|
for j in cutlass.range_constexpr(cute.size(acc_qk_mn, mode=[1])):
|
|
acc_qk_mn[i, j] = cute.math.exp2(
|
|
scale_softmax_log2 * acc_qk_mn[i, j] - scale_max, fastmath=True
|
|
)
|
|
|
|
_a_sum = 0.0
|
|
if cutlass.const_expr(not is_first_iter):
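# Online-softmax correction (informal): with m_prev and m_new the previous and current
# row maxima, the running sum and the accumulated P*V rows are rescaled by
# exp2((m_prev - m_new) * scale_softmax_log2) so earlier tiles remain consistent with
# the updated maximum.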
|
|
s_max_cur = s_max[i]
|
|
if s_max[i] == -cutlass.Float32.inf:
|
|
s_max_cur = 0.0
|
|
scale_pv = cute.math.exp2(
|
|
(s_max_prev[i] - s_max_cur) * scale_softmax_log2, fastmath=True
|
|
)
|
|
a_sum[i] *= scale_pv
|
|
|
|
for j in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[1])):
|
|
acc_pv_mn[i, j] *= scale_pv
|
|
|
|
_a_sum = a_sum[i]
|
|
|
|
a_sum[i] = _a_sum + acc_qk_mn[i, None].load().reduce(
|
|
cute.ReductionOp.ADD, cutlass.Float32.zero, 0
|
|
)
|
|
|
|
return s_max, a_sum
|
|
|
|
@cute.jit
|
|
def reduction_target_n(tiled_mma):
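# Returns the part of the C thread layout that maps along N; its shape gives the
# shuffle-group sizes used for the row-wise max/sum warp reductions.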
|
|
separated = layout_separate(
|
|
tiled_mma.shape_mnk[0],
|
|
cute.make_layout(tiled_mma.tv_layout_C.shape[0]),
|
|
tiled_mma.tv_layout_C.stride[0],
|
|
)
|
|
return separated[1]
|
|
|
|
@cute.jit
|
|
def convert_c_layout_to_a_layout(c, a):
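# (Informal) Reinterprets an accumulator (C) register layout so its value mode matches the
# A-operand value layout `a`, folding the leftover values into an extra trailing k-block mode.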
|
|
return cute.make_layout(
|
|
(a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))),
|
|
stride=(
|
|
c.stride[0],
|
|
c.stride[1],
|
|
(c.stride[2], cute.size(a, mode=[2]) * c.stride[0][2]),
|
|
),
|
|
)
|
|
|
|
@cute.jit
|
|
def make_acc_into_op(acc, operand_layout_tv, Element):
|
|
operand = cute.make_rmem_tensor_like(
|
|
convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
|
|
Element,
|
|
)
|
|
operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
|
|
acc_vec = acc.load()
|
|
operand_as_acc.store(acc_vec.to(Element))
|
|
|
|
if cutlass.const_expr(Element.width == 8 and True):
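# Assumed rationale: for 8-bit operands the WGMMA accumulator register layout differs
# from the A-operand layout required by the next MMA, so values are exchanged within
# each quad of threads via shuffles and byte permutes (patterns sketched below).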
|
|
## 00 11 22 33 00 11 22 33 acc layout
|
|
## 00 00 11 11 22 22 33 33 operand layout
|
|
## BB AA AA BB AA BB BB AA conflict-free exchange pattern
|
|
## 16-bit exchange; so process two at a time potentially
|
|
# int tid = threadIdx.x % 4;
|
|
tidx, _, _ = cute.arch.thread_idx()
|
|
tid = tidx % 4
|
|
values_u32 = cute.recast_tensor(operand, cutlass.Uint32)
|
|
for n in cutlass.range_constexpr(cute.size(values_u32, mode=[1])):
|
|
for k in cutlass.range_constexpr(cute.size(values_u32, mode=[2])):
|
|
for ii in cutlass.range_constexpr(0, 8, 4):
|
|
values_tmp_0 = values_u32[ii // 2 + 0, n, k]
|
|
values_tmp_1 = values_u32[ii // 2 + 1, n, k]
|
|
|
|
## step A:
|
|
## t 1 v 0 -> t 0 v 1
|
|
## t 2 v 0 -> t 1 v 0
|
|
## t 0 v 1 -> t 2 v 0
|
|
## t 3 v 1 -> t 3 v 1
|
|
|
|
v_to_send = 1
|
|
if tid == 1 or tid == 2:
|
|
v_to_send = 0
|
|
|
|
v_to_recv = v_to_send
|
|
t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF
|
|
|
|
values_tmp_a = values_tmp_1
|
|
if v_to_send == 0:
|
|
values_tmp_a = values_tmp_0
|
|
|
|
values_tmp_a = cute.arch.shuffle_sync_op(
|
|
values_tmp_a, t_to_recv_from, 0xFFFFFFFF, 7199
|
|
)
|
|
|
|
# step B:
|
|
# t 0 v 0 -> t 0 v 0
|
|
# t 3 v 0 -> t 1 v 1
|
|
# t 1 v 1 -> t 2 v 1
|
|
# t 2 v 1 -> t 3 v 0
|
|
|
|
v_to_send = 1 - v_to_send
|
|
v_to_recv = 1 - v_to_recv
|
|
t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF
|
|
|
|
values_tmp_b = values_tmp_1
|
|
if v_to_send == 0:
|
|
values_tmp_b = values_tmp_0
|
|
|
|
values_tmp_b = cute.arch.shuffle_sync_op(
|
|
values_tmp_b, t_to_recv_from, 0xFFFFFFFF, 7199
|
|
)
|
|
|
|
# __byte_perm
|
|
order = 0x5410
|
|
if v_to_send == 0:
|
|
order = 0x1054
|
|
|
|
values_u32[ii // 2 + 0, n, k] = cute.arch.prmt(
|
|
values_tmp_a,
|
|
values_tmp_b,
|
|
order,
|
|
)
|
|
|
|
order = 0x7632
|
|
if v_to_send == 0:
|
|
order = 0x3276
|
|
values_u32[ii // 2 + 1, n, k] = cute.arch.prmt(
|
|
values_tmp_a, values_tmp_b, order
|
|
)
|
|
return operand
|
|
|
|
@cute.jit
|
|
def tail(s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
|
|
"""
|
|
Final processing step for FMHA that computes log-sum-exp (LSE) and scales the output.
|
|
|
|
This function performs the following operations:
|
|
1. Reduces the attention sums across warps using butterfly shuffle
|
|
2. Computes the log-sum-exp (LSE) for numerical stability
|
|
3. Applies softmax scaling and output scaling to the accumulated values
|
|
4. Handles edge cases like zero sums and NaN values
|
|
|
|
:param s_max: Maximum attention scores for each position (for numerical stability)
|
|
:type s_max: cute.Tensor
|
|
:param a_sum: Sum of attention scores after softmax
|
|
:type a_sum: cute.Tensor
|
|
:param acc_pv: Accumulated P*V values from the attention computation
|
|
:type acc_pv: cute.ThrMma
|
|
:param tiled_mma_pv: Tiled MMA for P*V computation
|
|
:type tiled_mma_pv: cute.TiledMma
|
|
:param scale_softmax: Scaling factor for softmax computation
|
|
:type scale_softmax: cutlass.Float32
|
|
:param scale_output: Scaling factor for final output
|
|
:type scale_output: cutlass.Float32
|
|
|
|
:return: Log-sum-exp values for each position
|
|
:rtype: cute.Tensor
|
|
"""
|
|
# Create tensor view of accumulated P*V values with M*N layout
|
|
acc_pv_mn = cute.make_tensor(
|
|
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
|
)
|
|
reduction_target = reduction_target_n(tiled_mma_pv)
|
|
red_rank = cute.rank(reduction_target)
|
|
for r in cutlass.range_constexpr(red_rank):
|
|
for i in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[0])):
|
|
a_sum[i] = cute.arch.warp_reduction_sum(
|
|
a_sum[i], threads_in_group=reduction_target.shape[r]
|
|
)
|
|
|
|
acc_mn = cute.make_tensor(
|
|
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
|
)
|
|
|
|
lse = cute.make_rmem_tensor_like(a_sum, a_sum._dtype)
|
|
for i in cutlass.range_constexpr(cute.size(acc_mn, mode=[0])):
|
|
sum = a_sum[i]
|
|
inv_sum = cute.arch.rcp_approx(sum)
|
|
if sum == 0.0 or sum != sum:
|
|
inv_sum = 1.0
|
|
|
|
lse[i] = s_max[i] * scale_softmax + _math.log(sum)
|
|
if sum == 0.0 or sum != sum:
|
|
lse[i] = cutlass.Float32.inf
|
|
|
|
rp_dropout = 1
|
|
scale = rp_dropout * inv_sum
|
|
for j in cutlass.range_constexpr(cute.size(acc_mn, mode=[1])):
|
|
acc_mn[i, j] *= scale * scale_output
|
|
|
|
return lse
|
|
|
|
@cute.jit
|
|
def layout_separate(thr, src, ref):
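# Splits the modes of `src` into two groups according to `ref`: modes whose reference
# stride is below `thr` and the rest. Used here to separate the M-modes from the N-modes
# of an accumulator's value/thread layout.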
|
|
lt = cute.make_layout(())
|
|
ge = cute.make_layout(())
|
|
|
|
for k, v in enumerate(ref):
|
|
if cutlass.const_expr(v < thr):
|
|
lt = cute.append(lt, src[k])
|
|
else:
|
|
ge = cute.append(ge, src[k])
|
|
|
|
r = None
|
|
if cutlass.const_expr(cute.rank(lt) == 1):
|
|
r = cute.append(lt, ge)
|
|
else:
|
|
r = cute.append(cute.append(cute.make_layout(()), lt), ge)
|
|
return r
|
|
|
|
@cute.jit
|
|
def gemm_zero_acc(tiled_mma, A, B, C):
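# Issues the k-blocked WGMMA sequence C = A * B with the accumulator cleared on the first
# k-block: ACCUMULATE is set to False for k_block_idx == 0 and True for the remaining blocks.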
|
|
rA = cute.rank(A)
|
|
rB = cute.rank(B)
|
|
rC = cute.rank(C)
|
|
if cutlass.const_expr(rA == 2 and rB == 2 and rC == 1):
|
|
for k_block_idx in range(cute.size(A, mode=[1]), unroll_full=True):
|
|
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, k_block_idx != 0)
|
|
cute.gemm(
|
|
tiled_mma,
|
|
C,
|
|
A[None, k_block_idx],
|
|
B[None, k_block_idx],
|
|
C,
|
|
)
|
|
elif cutlass.const_expr(rA == 3 and rB == 3 and rC == 3):
|
|
for k_block_idx in range(cute.size(A, mode=[2]), unroll_full=True):
|
|
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, k_block_idx != 0)
|
|
cute.gemm(
|
|
tiled_mma,
|
|
C,
|
|
A[None, None, k_block_idx],
|
|
B[None, None, k_block_idx],
|
|
C,
|
|
)
|
|
else:
|
|
assert 0
|
|
|
|
@cute.jit
|
|
def layout_acc_mn(tiled_mma, acc):
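# Re-expresses an accumulator layout as a rank-2 (rows, columns) view by separating the
# value mode into its M and N parts, so per-row softmax statistics can be computed.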
|
|
separated = layout_separate(
|
|
tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1]
|
|
)
|
|
|
|
V_M = separated[0]
|
|
V_N = separated[1]
|
|
V_M1 = None
|
|
V_N1 = None
|
|
if cutlass.const_expr(cute.rank(V_M) == 1):
|
|
V_M1 = cute.append(V_M, acc[1])
|
|
else:
|
|
V_M1 = cute.append(cute.append(cute.make_layout(()), V_M), acc[1])
|
|
|
|
if cutlass.const_expr(cute.rank(V_N) == 1):
|
|
V_N1 = cute.append(V_N, acc[2])
|
|
else:
|
|
V_N1 = cute.append(cute.append(cute.make_layout(()), V_N), acc[2])
|
|
r = None
|
|
if cutlass.const_expr(cute.rank(V_M1) == 1):
|
|
r = cute.append(V_M1, V_N1)
|
|
else:
|
|
r = cute.append(cute.append(cute.make_layout(()), V_M1), V_N1)
|
|
return r
|
|
|
|
def make_and_init_load_q_pipeline(self, load_q_mbar_ptr):
|
|
load_q_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
len([self.load_warp_group_id]),
|
|
)
|
|
load_q_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.num_warps_per_warp_group,
|
|
)
|
|
return pipeline.PipelineTmaAsync.create(
|
|
barrier_storage=load_q_mbar_ptr,
|
|
num_stages=self.q_stage,
|
|
producer_group=load_q_producer_group,
|
|
consumer_group=load_q_consumer_group,
|
|
tx_count=self.tma_copy_q_bytes,
|
|
).make_participants()
|
|
|
|
def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr):
|
|
load_kv_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
len([self.load_warp_group_id]),
|
|
)
|
|
load_kv_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.num_mma_warp_groups * self.num_warps_per_warp_group,
|
|
)
|
|
return pipeline.PipelineTmaAsync.create(
|
|
barrier_storage=load_kv_mbar_ptr,
|
|
num_stages=self.kv_stage,
|
|
producer_group=load_kv_producer_group,
|
|
consumer_group=load_kv_consumer_group,
|
|
tx_count=self.tma_copy_kv_bytes,
|
|
).make_participants()
|
|
|
|
def make_and_init_tma_store_pipeline(self):
|
|
tma_store_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
1,
|
|
)
|
|
return pipeline.PipelineTmaStore.create(
|
|
num_stages=self.epi_stage,
|
|
producer_group=tma_store_producer_group,
|
|
)
|
|
|
|
def make_and_init_order_barrier(self, order_mbar_ptr, group_id):
|
|
StagesPerMathWarpGroup = 1
|
|
return pipeline.PipelineOrder.create(
|
|
barrier_storage=order_mbar_ptr,
|
|
depth=StagesPerMathWarpGroup,
|
|
length=self.num_mma_warp_groups,
|
|
group_id=group_id,
|
|
producer_group=pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.num_threads_per_warp_group,
|
|
),
|
|
)
|
|
|
|
@staticmethod
|
|
def _make_tma_atoms_and_tensors(
|
|
tensor: cute.Tensor,
|
|
smem_layout_staged: cute.ComposedLayout,
|
|
smem_tile: tuple[int, int],
|
|
mcast_dim: int,
|
|
) -> tuple[cute.CopyAtom, cute.Tensor]:
|
|
"""Create TMA atoms and tensors for input tensors.
|
|
|
|
:param tensor: Input tensor (A or B)
|
|
:type tensor: cute.Tensor
|
|
:param smem_layout_staged: Shared memory layout for the tensor
|
|
:type smem_layout_staged: cute.ComposedLayout
|
|
:param smem_tile: Shared memory tile shape
|
|
:type smem_tile: Tuple[int, int]
|
|
:param mcast_dim: Multicast dimension
|
|
:type mcast_dim: int
|
|
|
|
:return: TMA atom and tensor
|
|
:rtype: Tuple[cute.CopyAtom, cute.Tensor]
|
|
"""
|
|
op = (
|
|
cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
|
|
if mcast_dim == 1
|
|
else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
|
|
)
|
|
|
|
smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
|
|
tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
|
op,
|
|
tensor,
|
|
smem_layout,
|
|
smem_tile,
|
|
num_multicast=mcast_dim,
|
|
)
|
|
return tma_atom, tma_tensor
|
|
|
|
@staticmethod
|
|
def can_implement(
|
|
q_shape: Tuple[int, int, int, int],
|
|
k_shape: Tuple[int, int, int, int],
|
|
in_dtype: Type[cutlass.Numeric],
|
|
out_dtype: Type[cutlass.Numeric],
|
|
qk_acc_dtype: Type[cutlass.Numeric],
|
|
pv_acc_dtype: Type[cutlass.Numeric],
|
|
mma_tiler_mn: Tuple[int, int],
|
|
is_persistent: bool,
|
|
scale_softmax: float,
|
|
window_size: Tuple[int, int],
|
|
iterations: int,
|
|
) -> Tuple[bool, str]:
|
|
"""Check if the FMHA kernel can be implemented with the given parameters.
|
|
|
|
This method validates that the input parameters are compatible with the Hopper
|
|
Fused Multi-Head Attention implementation. It checks tensor shapes, data types,
|
|
window sizes, and other constraints to ensure the kernel can be successfully
|
|
compiled and executed.
|
|
|
|
:param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length,
|
|
H=number of heads, D=head dimension
|
|
:type q_shape: Tuple[int, int, int, int]
|
|
:param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length,
|
|
H_k=number of key heads, D=head dimension
|
|
:type k_shape: Tuple[int, int, int, int]
|
|
:param in_dtype: Input data type for query, key and value tensors
|
|
:type in_dtype: Type[cutlass.Numeric]
|
|
:param out_dtype: Output data type for attention output
|
|
:type out_dtype: Type[cutlass.Numeric]
|
|
:param qk_acc_dtype: Accumulator data type for query-key matrix multiplication
|
|
:type qk_acc_dtype: Type[cutlass.Numeric]
|
|
:param pv_acc_dtype: Accumulator data type for probability-value matrix multiplication
|
|
:type pv_acc_dtype: Type[cutlass.Numeric]
|
|
:param mma_tiler_mn: Matrix multiply accumulate tile shape (M, N)
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param is_persistent: Whether to use persistent kernel optimization
|
|
:type is_persistent: bool
|
|
:param scale_softmax: Attention score scaling factor
|
|
:type scale_softmax: float
|
|
:param window_size: Sliding window size (left, right) for attention masking
|
|
:type window_size: Tuple[int, int]
|
|
:param iterations: Number of iterations to run for performance testing
|
|
:type iterations: int
|
|
|
|
:return: Tuple of (can_implement, error_message) where can_implement is True if the kernel
|
|
can be implemented, False otherwise, and error_message contains the reason for failure
|
|
:rtype: Tuple[bool, str]
|
|
"""
|
|
|
|
# Unpack parameters
|
|
b, s_q, h, d = q_shape
|
|
b_, s_k, h_k, d_ = k_shape
|
|
window_size_left, window_size_right = window_size
|
|
|
|
if b != b_:
|
|
return False, "q & k must have the same batch size"
|
|
|
|
if d != d_:
|
|
return False, "q & k must have the same head dimension"
|
|
|
|
if window_size_left >= s_k - 1:
|
|
return False, "window_size_left must be less than s_k_max - 1"
|
|
if window_size_right >= s_q - 1:
|
|
return False, "window_size_right must be less than s_q_max - 1"
|
|
|
|
if h % h_k != 0:
|
|
return False, "h must be divisible by h_k"
|
|
|
|
if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16, cutlass.BFloat16}:
|
|
return False, "in_dtype must be Float16, BFloat16, Float8E4M3FN"
|
|
|
|
if out_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16, cutlass.BFloat16}:
|
|
return False, "out_dtype must be Float16, BFloat16, Float8E4M3FN"
|
|
|
|
if qk_acc_dtype not in {cutlass.Float32}:
|
|
return False, "qk_acc_dtype must be Float32"
|
|
|
|
if pv_acc_dtype not in {cutlass.Float32}:
|
|
return False, "pv_acc_dtype must be Float32"
|
|
|
|
if iterations < 1:
|
|
return False, "iterations must be at least 1"
|
|
|
|
if (
|
|
in_dtype.width == 16
|
|
and out_dtype.width == 16
|
|
and (
|
|
(d_ == 256 and mma_tiler_mn[1] >= 128)
|
|
or (d_ == 128 and mma_tiler_mn[1] >= 256)
|
|
)
|
|
) or (
|
|
in_dtype.width == 8
|
|
and out_dtype.width == 8
|
|
and d_ == 256
|
|
and mma_tiler_mn[1] >= 256
|
|
):
|
|
return False, "not enough smem"
|
|
|
|
if is_persistent and (
|
|
(
|
|
in_dtype.width == 16
|
|
and out_dtype.width == 16
|
|
and (
|
|
(d_ == 128 and mma_tiler_mn[1] >= 256)
|
|
or (d_ == 256 and mma_tiler_mn[1] > 32)
|
|
)
|
|
)
|
|
or (
|
|
in_dtype.width == 8
|
|
and out_dtype.width == 8
|
|
and d_ == 256
|
|
and mma_tiler_mn[1] == 256
|
|
)
|
|
):
|
|
return False, "not supported persistent"
|
|
|
|
return True, None
|
|
|
|
|
|
def run(
    q_shape: Tuple[int, int, int, int],
    k_shape: Tuple[int, int, int, int],
    in_dtype: Type[cutlass.Numeric],
    out_dtype: Type[cutlass.Numeric],
    qk_acc_dtype: Type[cutlass.Numeric],
    pv_acc_dtype: Type[cutlass.Numeric],
    mma_tiler_mn: Tuple[int, int],
    is_persistent: bool,
    is_causal: bool,
    bottom_right_align: bool,
    scale_q: float,
    scale_k: float,
    scale_v: float,
    inv_scale_o: float,
    scale_softmax: float,
    window_size: Tuple[int, int],
    tolerance: float,
    warmup_iterations: int,
    iterations: int,
    skip_ref_check: bool,
    use_cold_l2: bool = False,
    **kwargs,
):
"""Execute Fused Multi-Head Attention (FMHA) on Hopper architecture and validate results.
|
|
|
|
This function creates random input tensors for query, key, and value, then performs the
|
|
complete FMHA computation pipeline. It supports configurable data types, tiling parameters,
|
|
and various attention masking options. Results can be validated against a PyTorch reference
|
|
implementation or run multiple times for performance measurement.
|
|
|
|
The implementation leverages specialized tensor memory operations and efficient math
|
|
operations optimized for Hopper architecture, including pipelined computation stages
|
|
for maximum throughput.
|
|
|
|
:param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length,
|
|
H=number of heads, D=head dimension.
|
|
If S_q is a tuple, it is the variable sequence length.
|
|
:type q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int]
|
|
:param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length,
|
|
H_k=number of key heads (H must be divisible by H_k), D=head dimension.
|
|
If S_k is a tuple, it is the variable sequence length.
|
|
:type k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int]
|
|
:param in_dtype: Input data type for query, key and value tensors
|
|
:type in_dtype: Type[cutlass.Numeric]
|
|
:param out_dtype: Output data type for attention output
|
|
:type out_dtype: Type[cutlass.Numeric]
|
|
:param qk_acc_dtype: Accumulator data type for query-key matrix multiplication
|
|
:type qk_acc_dtype: Type[cutlass.Numeric]
|
|
:param pv_acc_dtype: Accumulator data type for probability-value matrix multiplication
|
|
:type pv_acc_dtype: Type[cutlass.Numeric]
|
|
:param mma_tiler_mn: Matrix multiply accumulate tile shape (M, N)
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param is_persistent: Whether to use persistent kernel optimization
|
|
:type is_persistent: bool
|
|
:param is_causal: Whether to apply causal masking
|
|
:type is_causal: bool
|
|
:param bottom_right_align: Whether to use bottom right align, under this settion, the end of q is aligned with the end of k.
|
|
:type bottom_right_align: bool
|
|
:param scale_q: Scaling factor for query tensor
|
|
:type scale_q: float
|
|
:param scale_k: Scaling factor for key tensor
|
|
:type scale_k: float
|
|
:param scale_v: Scaling factor for value tensor
|
|
:type scale_v: float
|
|
:param inv_scale_o: Inverse scaling factor for output tensor
|
|
:type inv_scale_o: float
|
|
:param scale_softmax: Attention score scaling factor (defaults to 1/sqrt(D) if set to 0)
|
|
:type scale_softmax: float
|
|
:param window_size: Sliding window size (left, right) for attention masking. Controls which positions each query can attend to. Negative values disable windowing.
|
|
:type window_size: Tuple[int, int]
|
|
:param tolerance: Maximum acceptable error for validation
|
|
:type tolerance: float
|
|
:param warmup_iterations: Number of warmup iterations
|
|
:type warmup_iterations: int
|
|
:param iterations: Number of iterations to run for performance testing
|
|
:type iterations: int
|
|
:param skip_ref_check: Skip validation against reference implementation
|
|
:type skip_ref_check: bool
|
|
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
|
|
:type use_cold_l2: bool
|
|
|
|
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
|
|
:raises RuntimeError: If GPU is unavailable for computation
|
|
:return: Execution time of the FMHA kernel in microseconds
|
|
:rtype: float
|
|
"""
|
|
print("Running Hopper SM90 FMHA test with:")
|
|
print(f" q_shape: {q_shape}")
|
|
print(f" k_shape: {k_shape}")
|
|
print(f" in_dtype: {in_dtype}")
|
|
print(f" out_dtype: {out_dtype}")
|
|
print(f" qk_acc_dtype: {qk_acc_dtype}")
|
|
print(f" pv_acc_dtype: {pv_acc_dtype}")
|
|
print(f" mma_tiler_mn: {mma_tiler_mn}")
|
|
print(f" is_persistent: {is_persistent}")
|
|
print(f" is_causal: {is_causal}")
|
|
print(f" bottom_right_align: {bottom_right_align}")
|
|
print(f" scale_q: {scale_q}")
|
|
print(f" scale_k: {scale_k}")
|
|
print(f" scale_v: {scale_v}")
|
|
print(f" inv_scale_o: {inv_scale_o}")
|
|
print(f" scale_softmax: {scale_softmax}")
|
|
print(f" window_size: {window_size}")
|
|
print(f" tolerance: {tolerance}")
|
|
print(f" skip_ref_check: {skip_ref_check}")
|
|
print(f" use_cold_l2: {use_cold_l2}")
|
|
|
|
    # Prepare pytorch tensors: Q, K, V, O and LSE (randomly initialized in [-2, 2])
    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")

    ret, msg = HopperFusedMultiHeadAttentionForward.can_implement(
        q_shape,
        k_shape,
        in_dtype,
        out_dtype,
        qk_acc_dtype,
        pv_acc_dtype,
        mma_tiler_mn,
        is_persistent,
        scale_softmax,
        window_size,
        iterations,
    )
    if not ret:
        raise TypeError(msg)

    # Unpack parameters
    b, s_q, h, d = q_shape
    b_, s_k, h_k, d_ = k_shape
    window_size_left, window_size_right = window_size
    if window_size_left == -1:
        window_size_left = None
    if window_size_right == -1:
        window_size_right = None

    h_r = h // h_k

    torch.manual_seed(1111)

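    # Illustrative note: h_r is the number of query heads that share a single K/V head
    # (grouped-query attention). For example, h=16 with h_k=16 gives h_r=1 (plain MHA),
    # while h=8 with h_k=2 gives h_r=4, i.e. each K/V head serves 4 query heads. The
    # helper below keeps (h_r, h_k) as separate tensor modes so the kernel can index the
    # shared K/V head directly.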
    def create_and_permute_tensor(
        b, s, h_k, h_r, d, dtype, is_dynamic_layout=True, tensor_name=""
    ):
        # (b, s, h_k, h_r, d) -> (s, d, h_r, h_k, b)
        # torch SDPA order is (h_k, h_r), then kernel is (h_r, h_k)
        shape = (b, s, h_k, h_r, d)
        permute_order = (1, 4, 3, 2, 0)
        is_fp8 = dtype in {cutlass.Float8E4M3FN}
        leading_dim = 1
        if is_fp8 and tensor_name == "v":
            permute_order = (4, 1, 3, 2, 0)
            leading_dim = 0
            shape = (b, d, h_k, h_r, s)

        # torch does not support fp8 types
        torch_dtype = cutlass.torch.dtype(dtype) if not is_fp8 else torch.int8

        # Create dtype torch tensor (cpu)
        torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor(
            shape,
            torch_dtype,
            permute_order=permute_order,
            init_type=cutlass.torch.TensorInitType.RANDOM,
            init_config=cutlass.torch.RandomInitConfig(
                min_val=-2,
                max_val=2,
            ),
        )
        # Create dtype torch tensor (gpu)
        torch_tensor_gpu = torch_tensor_cpu.cuda()

        # Create f32 torch tensor (cpu)
        f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)

        # Create dtype cute tensor (gpu)
        cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16)
        cute_tensor.element_type = dtype
        if is_dynamic_layout:
            cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
        cute_tensor = cutlass_torch.convert_cute_tensor(
            f32_torch_tensor,
            cute_tensor,
            dtype,
            is_dynamic_layout=is_dynamic_layout,
        )

        return f32_torch_tensor, cute_tensor, torch_tensor_gpu

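    # Shape sketch (illustrative, assuming the default q_shape=(1, 128, 16, 128), so h_r=1):
    # create_and_permute_tensor allocates a (b, s, h_k, h_r, d) buffer and permutes it to
    # (s, d, h_r, h_k, b), so Q becomes a (128, 128, 1, 16, 1) view with d as the contiguous
    # mode (leading_dim=1). For fp8 V the buffer is allocated as (b, d, h_k, h_r, s) instead,
    # so that after the permute the sequence dimension is the contiguous one (leading_dim=0).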
    q_ref, q_tensor, q_torch = create_and_permute_tensor(
        b, s_q, h_k, h_r, d, in_dtype, is_dynamic_layout=True
    )
    k_ref, k_tensor, k_torch = create_and_permute_tensor(
        b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True
    )
    v_ref, v_tensor, v_torch = create_and_permute_tensor(
        b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True, tensor_name="v"
    )
    o_ref, o_tensor, o_torch = create_and_permute_tensor(
        b, s_q, h_k, h_r, d, out_dtype, is_dynamic_layout=True
    )
    lse_ref, lse_tensor, lse_torch = create_and_permute_tensor(
        b, s_q, h_k, h_r, 1, qk_acc_dtype, is_dynamic_layout=True
    )

    mma_tiler = (*mma_tiler_mn, d)

    mask_type = fmha_utils.MaskEnum.WINDOW_MASK
    if bottom_right_align:
        mask_type = fmha_utils.MaskEnum.WINDOW_MASK_INFERENCE
    if is_causal:
        window_size_right = 0
    elif window_size_left is None and window_size_right is None:
        if s_k % mma_tiler_mn[1] != 0:
            mask_type = fmha_utils.MaskEnum.RESIDUAL_MASK

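    # Illustrative note on the selection above: WINDOW_MASK applies the causal/sliding-window
    # limits, WINDOW_MASK_INFERENCE additionally aligns the end of q with the end of k
    # (bottom-right alignment), and RESIDUAL_MASK is chosen when no windowing is requested
    # but s_k is not a multiple of the MMA tile N, so the out-of-bounds columns of the last
    # key tile must be masked (e.g. s_k=1000 with tile N=128 leaves a 104-column remainder).
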
    # To avoid masking out a whole row, which would produce NaN in the softmax
    def check_seqlen_valid(
        s_q, s_k, window_size_left, window_size_right, bottom_right_align
    ):
        for i in range(s_q):
            offset = 0 if not bottom_right_align else s_k - s_q

            s_q_start = 0 if window_size_left is None else i + offset - window_size_left
            s_q_end = (
                s_q if window_size_right is None else i + offset + window_size_right
            )
            s_q_min = max(s_q_start, 0)
            s_q_max = min(s_q_end, s_k)

            if s_q_max - s_q_min == 0 and (i != 0 and i != s_q - 1):
                return False
        return True

    need_check_seqlen_valid = (
        window_size_left is not None or window_size_right is not None
    )
    if need_check_seqlen_valid and not check_seqlen_valid(
        s_q,
        s_k,
        window_size_left,
        window_size_right,
        bottom_right_align,
    ):
        raise ValueError("sliding window does not support the current setting")

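    # Worked example (illustrative): with s_q = s_k = 8, bottom_right_align=False and
    # window_size=(2, 1), query i may attend to keys in [max(i - 2, 0), min(i + 1, 7)],
    # so query 4 sees keys 2..5; with is_causal=True the right window becomes 0 and query 4
    # sees keys 2..4. check_seqlen_valid above rejects settings where this span is empty for
    # an interior query, since a fully masked row would turn the softmax into NaN.
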
    fmha = HopperFusedMultiHeadAttentionForward(
        qk_acc_dtype,
        pv_acc_dtype,
        mma_tiler,
        is_persistent,
        mask_type,
    )

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

    if scale_softmax == 0.0:  # default to 1/sqrt(d)
        scale_softmax = 1.0 / math.sqrt(q_shape[3])  # q_shape[3] is the head dimension d

    scale_softmax = scale_q * scale_k * scale_softmax

    LOG2_E = 1.4426950408889634074
    scale_softmax_log2 = scale_softmax * LOG2_E
    scale_output = scale_v * inv_scale_o

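    # Illustrative note on the scaling above: the kernel evaluates the softmax with exp2
    # rather than exp, using exp(x * scale_softmax) == exp2(x * scale_softmax * log2(e)),
    # which is why LOG2_E is folded into scale_softmax_log2. scale_softmax itself already
    # absorbs the Q/K dequantization factors, while scale_output = scale_v * inv_scale_o
    # is applied once to the P*V result.
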
print("Compiling kernel with cute.compile ...")
|
|
start_time = time.time()
|
|
# compile fmha kernel
|
|
compiled_fmha = cute.compile(
|
|
fmha,
|
|
q_tensor,
|
|
k_tensor,
|
|
v_tensor,
|
|
o_tensor,
|
|
lse_tensor,
|
|
scale_softmax_log2,
|
|
scale_softmax,
|
|
scale_output,
|
|
(
|
|
window_size_left
|
|
if window_size_left is None
|
|
else cutlass.Int32(window_size_left)
|
|
),
|
|
(
|
|
window_size_right
|
|
if window_size_right is None
|
|
else cutlass.Int32(window_size_right)
|
|
),
|
|
current_stream,
|
|
)
|
|
compilation_time = time.time() - start_time
|
|
print(f"Compilation time: {compilation_time:.4f} seconds")
|
|
|
|
    def run_torch_fmha(
        q,
        k,
        v,
        scale_softmax=1.0,
        scale_output=1.0,
        is_causal=False,
        window_size_left=None,
        window_size_right=None,
    ):
        s_q, d, h_r, h_k, b = q.shape
        s_k = k.shape[0]

        # broadcast k and v to have the same shape as q
        k = k.expand(s_k, d, h_r, h_k, b)
        v = v.expand(s_k, d, h_r, h_k, b)

        q_tmp = q.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_q, d)
        k_tmp = k.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_k, d)
        v_tmp = v.permute(4, 2, 3, 0, 1).contiguous().view(b, -1, s_k, d)

        cur_S = torch.einsum("bhqd,bhkd->bhqk", q_tmp, k_tmp)

        # For causal masking, disable right-side windowing (no future tokens)
        if is_causal:
            window_size_right = 0

        if window_size_left is not None or window_size_right is not None:
            q_coords = torch.arange(0, s_q).cuda().view(-1, 1)
            k_coords = torch.arange(0, s_k).cuda().view(1, -1)
            offset = 0 if not bottom_right_align else s_k - s_q
            if window_size_left is None:
                _mask = k_coords > q_coords + offset + window_size_right
            elif window_size_right is None:
                _mask = k_coords < q_coords + offset - window_size_left
            else:
                _mask = (k_coords > q_coords + offset + window_size_right) | (
                    k_coords < q_coords + offset - window_size_left
                )
            cur_S = cur_S.masked_fill(_mask.cpu(), -torch.inf)

        p_tmp = torch.softmax(cur_S * scale_softmax, dim=-1)
        ref = torch.einsum("bhsl,bhld->bhsd", p_tmp, v_tmp)
        ref = ref.view(b, h_r, h_k, s_q, d).permute(3, 4, 1, 2, 0) * scale_output

        cur_S_max = torch.max(cur_S, dim=-1, keepdim=True).values
        cur_S_sum = torch.sum(
            torch.exp2((cur_S - cur_S_max) * scale_softmax_log2), dim=-1, keepdim=True
        )

        lse = cur_S_max * scale_softmax + torch.log(cur_S_sum)

        # [B, H, Q, 1] -> [Q, 1, H_r, H_k, B]
        lse = lse.permute(2, 3, 1, 0).contiguous().view(s_q, 1, h_r, h_k, b)

        return ref, lse

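    # Illustrative note on the reference LSE above: since
    # exp2((x - m) * scale_softmax * log2(e)) == exp((x - m) * scale_softmax), the quantity
    # lse = m * scale_softmax + log(sum_j exp((x_j - m) * scale_softmax)) is exactly
    # logsumexp(scale_softmax * x) over the key dimension, computed in a numerically stable
    # way by subtracting the row maximum m first.
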
    if not skip_ref_check:
        # Execute the kernel once for verification
        compiled_fmha(
            q_tensor,
            k_tensor,
            v_tensor,
            o_tensor,
            lse_tensor,
            scale_softmax_log2,
            scale_softmax,
            scale_output,
            (
                window_size_left
                if window_size_left is None
                else cutlass.Int32(window_size_left)
            ),
            (
                window_size_right
                if window_size_right is None
                else cutlass.Int32(window_size_right)
            ),
            current_stream,
        )

        print("Verifying results...")
        o_ref, lse_ref = run_torch_fmha(
            q_ref,
            k_ref,
            v_ref,
            scale_softmax,
            scale_output,
            is_causal,
            window_size_left,
            window_size_right,
        )

        # convert o back to f32 for comparison
        o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like(
            torch.empty(*o_torch.shape, dtype=torch.float32),
            cutlass.Float32,
            is_dynamic_layout=True,
            assumed_align=16,
        )
        cute.testing.convert(o_tensor, o_fp32)

        ref_o_f32, ref_o_f32_torch = cutlass_torch.cute_tensor_like(
            o_ref,
            cutlass.Float32,
            is_dynamic_layout=True,
            assumed_align=16,
        )

        if out_dtype.is_float and out_dtype.width <= 8:
            ref_narrow_precision, _ = cutlass_torch.cute_tensor_like(
                torch.empty_strided(o_ref.shape, o_ref.stride(), dtype=torch.uint8),
                out_dtype,
                is_dynamic_layout=True,
                assumed_align=16,
            )

            # convert ref : f32 -> fp4/fp8 -> f32
            cute.testing.convert(ref_o_f32, ref_narrow_precision)
            cute.testing.convert(ref_narrow_precision, ref_o_f32)

        # check output against the reference
        torch.testing.assert_close(
            o_fp32_torch, ref_o_f32_torch, atol=tolerance, rtol=1e-05
        )

        # check lse against the reference
        lse_result = lse_torch.cpu()
        torch.testing.assert_close(lse_result, lse_ref, atol=tolerance, rtol=1e-05)

        print("Results verified successfully!")

    def generate_tensors():
        _, q_tensor_workspace, _ = create_and_permute_tensor(
            b, s_q, h_k, h_r, d, in_dtype, is_dynamic_layout=True
        )
        _, k_tensor_workspace, _ = create_and_permute_tensor(
            b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True
        )
        _, v_tensor_workspace, _ = create_and_permute_tensor(
            b, s_k, h_k, 1, d, in_dtype, is_dynamic_layout=True, tensor_name="v"
        )
        _, o_tensor_workspace, _ = create_and_permute_tensor(
            b, s_q, h_k, h_r, d, out_dtype, is_dynamic_layout=True
        )
        _, lse_tensor_workspace, _ = create_and_permute_tensor(
            b, s_q, h_k, h_r, 1, qk_acc_dtype, is_dynamic_layout=True
        )

        return testing.JitArguments(
            q_tensor_workspace,
            k_tensor_workspace,
            v_tensor_workspace,
            o_tensor_workspace,
            lse_tensor_workspace,
            scale_softmax_log2,
            scale_softmax,
            scale_output,
            (
                window_size_left
                if window_size_left is None
                else cutlass.Int32(window_size_left)
            ),
            (
                window_size_right
                if window_size_right is None
                else cutlass.Int32(window_size_right)
            ),
            current_stream,
        )

    workspace_count = 1
    if use_cold_l2:
        q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
        k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
        v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
        o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
        lse_torch_effective = lse_torch.values() if lse_torch.is_nested else lse_torch
        one_workspace_bytes = (
            q_torch_effective.numel() * q_torch_effective.element_size()
            + k_torch_effective.numel() * k_torch_effective.element_size()
            + v_torch_effective.numel() * v_torch_effective.element_size()
            + o_torch_effective.numel() * o_torch_effective.element_size()
            + lse_torch_effective.numel() * lse_torch_effective.element_size()
        )
        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )

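    # Illustrative sizing for the cold-L2 path above (not part of the original source): with
    # the default (1, 128, 16, 128) shapes, fp16 Q/K/V/O and an fp32 LSE, one workspace is
    # 4 * (1*128*16*128) * 2 bytes + (1*128*16) * 4 bytes = 2,105,344 bytes (~2 MiB), and
    # get_workspace_count picks enough copies of that footprint so each timed iteration
    # starts from a cold L2.
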
    exec_time = testing.benchmark(
        compiled_fmha,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        stream=current_stream,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    return exec_time  # Return execution time in microseconds


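# Programmatic usage sketch (illustrative; mirrors the CLI defaults defined below):
#
#   exec_time_us = run(
#       q_shape=(1, 128, 16, 128),
#       k_shape=(1, 128, 16, 128),
#       in_dtype=cutlass.Float16,
#       out_dtype=cutlass.Float16,
#       qk_acc_dtype=cutlass.Float32,
#       pv_acc_dtype=cutlass.Float32,
#       mma_tiler_mn=(64, 128),
#       is_persistent=False,
#       is_causal=True,
#       bottom_right_align=False,
#       scale_q=1.0, scale_k=1.0, scale_v=1.0, inv_scale_o=1.0,
#       scale_softmax=0.0,        # 0.0 falls back to 1/sqrt(D)
#       window_size=(-1, -1),     # (-1, -1) disables sliding-window masking
#       tolerance=1e-1,
#       warmup_iterations=2,
#       iterations=10,
#       skip_ref_check=False,
#   )
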
if __name__ == "__main__":

    def parse_comma_separated_ints(s: str):
        try:
            return [int(x.strip()) for x in s.split(",")]
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    parser = argparse.ArgumentParser(
        description="""
This example showcases the use of CUTE DSL builders to easily construct fused multi-head attention forward-pass kernels targeting NVIDIA's Hopper architecture.
"""
    )

    parser.add_argument(
        "--in_dtype",
        type=cutlass.dtype,
        default=cutlass.Float16,
        help="Input data type",
    )

    parser.add_argument(
        "--out_dtype",
        type=cutlass.dtype,
        default=cutlass.Float16,
        help="Output data type",
    )

    parser.add_argument(
        "--qk_acc_dtype",
        type=cutlass.dtype,
        default=cutlass.Float32,
        help="QK accumulator data type",
    )

    parser.add_argument(
        "--pv_acc_dtype",
        type=cutlass.dtype,
        default=cutlass.Float32,
        help="PV accumulator data type",
    )

    parser.add_argument(
        "--mma_tile_shape_mn",
        type=parse_comma_separated_ints,
        default=[64, 128],
        help="MMA tile shape (M, N)",
    )

    parser.add_argument(
        "--is_persistent",
        action="store_true",
        help="Whether to use the persistent kernel",
    )

    parser.add_argument(
        "--is_causal",
        action="store_true",
        help="Whether to use causal mask",
    )

    parser.add_argument(
        "--q_shape",
        type=parse_comma_separated_ints,
        default=[1, 128, 16, 128],
        help="Shape of Q (B, S_q, H, D)",
    )

    parser.add_argument(
        "--k_shape",
        type=parse_comma_separated_ints,
        default=[1, 128, 16, 128],
        help="Shape of K (B, S_k, H_k, D)",
    )

    parser.add_argument(
        "--scale_q",
        type=float,
        default=1.0,
        help="Scaling factor to dequantize Q",
    )

    parser.add_argument(
        "--scale_k",
        type=float,
        default=1.0,
        help="Scaling factor to dequantize K",
    )

    parser.add_argument(
        "--scale_v",
        type=float,
        default=1.0,
        help="Scaling factor to dequantize V",
    )

    parser.add_argument(
        "--inv_scale_o",
        type=float,
        default=1.0,
        help="Scaling factor to quantize O",
    )

    parser.add_argument(
        "--scale_softmax",
        type=float,
        default=1.0,
        help="Scaling factor to scale S (i.e. Q*K); if zero, defaults to 1/sqrt(D)",
    )

    parser.add_argument(
        "--window_size",
        type=parse_comma_separated_ints,
        default=(-1, -1),
        help="Sliding window size (left, right) for attention masking.",
    )

    parser.add_argument(
        "--bottom_right_align",
        action="store_true",
        help="Whether to use bottom-right alignment; with this setting, the end of q is aligned with the end of k.",
    )

    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )

    parser.add_argument(
        "--warmup_iterations",
        type=int,
        default=0,
        help="Number of iterations for warmup",
    )

    parser.add_argument(
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations after warmup",
    )

    parser.add_argument(
        "--skip_ref_check",
        action="store_true",
        help="Skip reference check",
    )

    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        default=False,
        help="Use circular buffer tensor sets to ensure a cold L2 cache",
    )

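    # Example invocations (illustrative) using the flags defined above:
    #   python fmha.py --q_shape 1,128,16,128 --k_shape 1,128,16,128 --is_causal
    #   python fmha.py --q_shape 2,512,8,64 --k_shape 2,512,8,64 \
    #       --window_size 128,0 --bottom_right_align
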
    args = parser.parse_args()

    if len(args.q_shape) != 4:
        parser.error("--q_shape must contain exactly 4 values")

    if len(args.k_shape) != 4:
        parser.error("--k_shape must contain exactly 4 values")

    if len(args.mma_tile_shape_mn) != 2:
        parser.error("--mma_tile_shape_mn must contain exactly 2 values")

    run(
        args.q_shape,
        args.k_shape,
        args.in_dtype,
        args.out_dtype,
        args.qk_acc_dtype,
        args.pv_acc_dtype,
        args.mma_tile_shape_mn,
        args.is_persistent,
        args.is_causal,
        args.bottom_right_align,
        args.scale_q,
        args.scale_k,
        args.scale_v,
        args.inv_scale_o,
        args.scale_softmax,
        args.window_size,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )

    print("PASS")