# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
from types import SimpleNamespace
from typing import Type, Union, Callable

import torch
import cuda.bindings.driver as cuda
import cutlass.cute.testing as testing
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, warp
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
import cutlass.utils as utils
"""
|
|
A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL.
|
|
|
|
- Matrix Q is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension
|
|
- Matrix K is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension
|
|
- Matrix V is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension
|
|
- Matrix O is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension
|
|
|
|
This kernel supports the following features:
|
|
- Utilizes CpAsync for efficient memory operations
|
|
- Utilizes Ampere's tensor core for matrix multiply-accumulate (MMA) operations
|
|
- Utilizes register pipeline to overlap shared memory-to-register transfers with computations.
|
|
- Leverages DSL to implement an integrated online softmax fusion pattern.
|
|
|
|
This kernel works as follows:
|
|
1. Load Q and K matrices from global memory (GMEM) to shared memory (SMEM) using CpAsync operations.
|
|
2. Perform matrix multiply-accumulate (MMA) operations using tensor core instructions to compute intermediate result S.
|
|
3. Apply padding mask or causal mask to S during initial iterations.
|
|
4. Apply online softmax to S and rescale O using results from previous iteration.
|
|
5. Load V matrices and perform matrix multiply-accumulate (MMA) operations to compute final result O.
|
|
6. Normalize O after all iterations complete and store result back to global memory (GMEM).
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/ampere/flash_attention_v2.py \
|
|
--dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \
|
|
--num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \
|
|
--num_head 16 --softmax_scale 1.0 --is_causal
|
|
|
|
The above command configures the model to use float16 for inputs and outputs. The problem dimensions
|
|
are set to: batch size of 1, query sequence length of 1280, key sequence length of 1536, head dimension
|
|
of 128, and 16 attention heads. The softmax scale is set to 1.0 and causal masking is enabled. The computation
|
|
uses tiles of size 128x128 for m and n dimensions, and utilizes 128 parallel threads.
|
|
|
|
To collect the performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
|
|
ncu python examples/ampere/flash_attention_v2.py \
|
|
--dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \
|
|
--num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \
|
|
--num_head 16 --softmax_scale 1.0 --is_causal --skip_ref_check
|
|
|
|
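
As a functional reference only (this is not how the kernel is implemented internally), the computation
performed for one batch and one head is equivalent to the following PyTorch sketch, where `q`/`k`/`v`,
`Sq`/`Sk`, `scale` and `is_causal` are illustrative names standing in for this example's inputs:

.. code-block:: python

    # q: (Sq, H), k and v: (Sk, H); scale corresponds to softmax_scale.
    s = (q @ k.transpose(-1, -2)) * scale                 # (Sq, Sk) attention scores
    if is_causal:
        mask = torch.tril(torch.ones(Sq, Sk, dtype=torch.bool, device=q.device))
        s = s.masked_fill(~mask, float("-inf"))           # query i attends only to keys 0..i
    o = torch.softmax(s, dim=-1) @ v                      # (Sq, H) output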

There are some constraints for this example:
* Only fp16 and bf16 data types are supported.
* The contiguous dimension of each tensor must be at least 16-byte aligned.
* The log-sum-exp (needed for training) is not computed in the kernel.
* The values of `m_block_size`, `n_block_size`, and `head_dim` must be selected to stay within the shared memory capacity limit.
* `m_block_size * 2` must be divisible by `num_threads`; otherwise the kernel will not produce correct results.
"""


class FlashAttentionForwardAmpere:
    def __init__(
        self,
        head_dim: int,
        m_block_size: int = 128,
        n_block_size: int = 128,
        num_threads: int = 128,
        is_causal: bool = False,
    ):
        """Initializes the configuration for a flash attention v2 kernel.

        All contiguous dimensions must be at least 16-byte aligned, which means the head dimension
        must be a multiple of 8 for fp16/bf16 elements.

        :param head_dim: head dimension
        :type head_dim: int
        :param m_block_size: m block size
        :type m_block_size: int
        :param n_block_size: n block size
        :type n_block_size: int
        :param num_threads: number of threads
        :type num_threads: int
        :param is_causal: whether to apply a causal mask
        :type is_causal: bool
        """
        self._head_dim = head_dim
        self._m_block_size = m_block_size
        self._n_block_size = n_block_size
        # Pad head_dim to a multiple of 32 and use it as the k block size.
        self._head_dim_padded = (head_dim + 31) // 32 * 32
        self._num_threads = num_threads
        self._is_causal = is_causal

    @staticmethod
    def can_implement(
        dtype, head_dim, m_block_size, n_block_size, num_threads, is_causal
    ) -> bool:
        """Check if the kernel can be implemented with the given parameters.

        :param dtype: data type
        :type dtype: cutlass.Numeric
        :param head_dim: head dimension
        :type head_dim: int
        :param m_block_size: m block size
        :type m_block_size: int
        :param n_block_size: n block size
        :type n_block_size: int
        :param num_threads: number of threads
        :type num_threads: int
        :param is_causal: whether to apply a causal mask
        :type is_causal: bool

        :return: True if the kernel can be implemented, False otherwise
        :rtype: bool
        """
        # Check if the data type is fp16 or bf16
        if dtype != cutlass.Float16 and dtype != cutlass.BFloat16:
            return False

        # Check if the head dimension is a multiple of 8
        if head_dim % 8 != 0:
            return False

        # Check if the number of threads is a multiple of 32 (the warp size)
        if num_threads % 32 != 0:
            return False

        # Check whether the block size setting exceeds the shared memory capacity.
        # Shared memory usage: Q tile + (K tile + V tile), where K and V use the same tile size;
        # the trailing factor of 2 is the element size in bytes (fp16/bf16).
        smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2
        smem_capacity = utils.get_smem_capacity_in_bytes("sm_80")
        if smem_usage > smem_capacity:
            return False

        # Check if twice the m block size is divisible by the number of threads
        if (m_block_size * 2) % num_threads != 0:
            return False

        return True

    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mO: cute.Tensor,
        softmax_scale: cutlass.Float32,
        stream: cuda.CUstream,
    ):
        """Configures and launches the flash attention v2 kernel.

        mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout:
        (batch_size, seqlen_q, num_head, head_dim):(seqlen_q * num_head * head_dim, num_head * head_dim, head_dim, 1)

        Prepares the shared memory layouts, tiled copy atoms, tiled mma and shared memory storage,
        then launches the kernel function with the prepared parameters.

        :param mQ: query tensor
        :type mQ: cute.Tensor
        :param mK: key tensor
        :type mK: cute.Tensor
        :param mV: value tensor
        :type mV: cute.Tensor
        :param mO: output tensor
        :type mO: cute.Tensor
        :param softmax_scale: softmax scale
        :type softmax_scale: cutlass.Float32
        :param stream: CUDA stream to launch the kernel on
        :type stream: cuda.CUstream
        """
        # Get the data type and check that it is fp16 or bf16
        if cutlass.const_expr(
            not (
                mQ.element_type == mK.element_type == mV.element_type == mO.element_type
            )
        ):
            raise TypeError("All tensors must have the same data type")
        if cutlass.const_expr(
            not (
                mQ.element_type == cutlass.Float16
                or mQ.element_type == cutlass.BFloat16
            )
        ):
            raise TypeError("Only Float16 or BFloat16 is supported")
        self._dtype: Type[cutlass.Numeric] = mQ.element_type
        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory layout: Q/K/V
        # ///////////////////////////////////////////////////////////////////////////////
        smem_k_block_size = 64 if self._head_dim_padded % 64 == 0 else 32
        swizzle_bits = 3 if smem_k_block_size == 64 else 2
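        # The smem layout atom is an (8, smem_k_block_size) row-major tile composed with a swizzle;
        # the swizzle spreads consecutive rows across shared memory banks so that the ldmatrix
        # copies used by the mma below avoid bank conflicts. tile_to_shape then repeats this atom
        # over the full (block_size, head_dim_padded) tile.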
        sQ_layout_atom = cute.make_composed_layout(
            cute.make_swizzle(swizzle_bits, 3, 3),
            0,
            cute.make_layout((8, smem_k_block_size), stride=(smem_k_block_size, 1)),
        )
        sQ_layout = cute.tile_to_shape(
            sQ_layout_atom,
            (self._m_block_size, self._head_dim_padded),
            (0, 1),
        )

        sKV_layout_atom = sQ_layout_atom
        sKV_layout = cute.tile_to_shape(
            sKV_layout_atom,
            (self._n_block_size, self._head_dim_padded),
            (0, 1),
        )

        sO_layout = sQ_layout

        @cute.struct
        class SharedStorage:
            sQ: cute.struct.Align[
                cute.struct.MemRange[self._dtype, cute.cosize(sQ_layout)], 1024
            ]
            sK: cute.struct.Align[
                cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024
            ]
            sV: cute.struct.Align[
                cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024
            ]

        # ///////////////////////////////////////////////////////////////////////////////
        # GMEM Tiled copy:
        # ///////////////////////////////////////////////////////////////////////////////
        # Thread layouts for copies
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // self._dtype.width
        # atom_async_copy: async copy atom for QKV load
        atom_async_copy = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            self._dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        # atom_universal_copy: universal copy atom for O store
        atom_universal_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            self._dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        # tQKV_layout: thread layout for QKV load
        tQKV_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
        tQKV_layout = cute.make_layout(
            (self._num_threads // tQKV_shape_dim_1, tQKV_shape_dim_1),
            stride=(tQKV_shape_dim_1, 1),
        )
        # tO_layout: thread layout for O store
        tO_layout = tQKV_layout

        # Value layouts for copies
        vQKV_layout = cute.make_layout((1, async_copy_elems))
        vO_layout = vQKV_layout

        # gmem_tiled_copy_QKV: tiled copy for QKV load
        gmem_tiled_copy_QKV = cute.make_tiled_copy_tv(
            atom_async_copy, tQKV_layout, vQKV_layout
        )
        # gmem_tiled_copy_O: tiled copy for O store
        gmem_tiled_copy_O = cute.make_tiled_copy_tv(
            atom_universal_copy, tO_layout, vO_layout
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Tiled mma
        # ///////////////////////////////////////////////////////////////////////////////
        tiled_mma = cute.make_tiled_mma(
            warp.MmaF16BF16Op(self._dtype, cutlass.Float32, (16, 8, 16)),
            (self._num_threads // 32, 1, 1),
            permutation_mnk=(self._num_threads // 32 * 16, 16, 16),
        )

        # grid_dim: (m_block, batch_size, num_head)
        grid_dim = (
            cute.ceil_div(mQ.shape[1], self._m_block_size),
            cute.size(mQ.shape[0]),
            cute.size(mQ.shape[2]),
        )
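        # Fold log2(e) into the softmax scale so the kernel can use the fast exp2
        # instruction instead of exp: exp(x * scale) == exp2(x * scale * log2(e)).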
        LOG2_E = 1.4426950408889634074
        softmax_scale_log2 = softmax_scale * LOG2_E
        self.kernel(
            mQ,
            mK,
            mV,
            mO,
            softmax_scale_log2,
            sQ_layout,
            sKV_layout,
            sO_layout,
            gmem_tiled_copy_QKV,
            gmem_tiled_copy_O,
            tiled_mma,
            SharedStorage,
        ).launch(
            grid=grid_dim,
            block=[self._num_threads, 1, 1],
            smem=SharedStorage.size_in_bytes(),
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mO: cute.Tensor,
        softmax_scale_log2: cutlass.Float32,
        sQ_layout: cute.ComposedLayout,
        sKV_layout: cute.ComposedLayout,
        sO_layout: cute.ComposedLayout,
        gmem_tiled_copy_QKV: cute.TiledCopy,
        gmem_tiled_copy_O: cute.TiledCopy,
        tiled_mma: cute.TiledMma,
        SharedStorage: cutlass.Constexpr,
    ):
        """Kernel function for flash attention v2.

        :param mQ: query tensor
        :type mQ: cute.Tensor
        :param mK: key tensor
        :type mK: cute.Tensor
        :param mV: value tensor
        :type mV: cute.Tensor
        :param mO: output tensor
        :type mO: cute.Tensor
        :param softmax_scale_log2: softmax scale pre-multiplied by log2(e)
        :type softmax_scale_log2: cutlass.Float32
        :param sQ_layout: query smem layout
        :type sQ_layout: cute.ComposedLayout
        :param sKV_layout: key/value smem layout
        :type sKV_layout: cute.ComposedLayout
        :param sO_layout: output smem layout
        :type sO_layout: cute.ComposedLayout
        :param gmem_tiled_copy_QKV: tiled copy for QKV load
        :type gmem_tiled_copy_QKV: cute.TiledCopy
        :param gmem_tiled_copy_O: tiled copy for O store
        :type gmem_tiled_copy_O: cute.TiledCopy
        :param tiled_mma: tiled mma
        :type tiled_mma: cute.TiledMma
        :param SharedStorage: shared storage struct type
        :type SharedStorage: cutlass.Constexpr
        """
        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()
        m_block, batch_size, num_head = cute.arch.block_idx()
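
        # n_block_max is the number of key blocks this query block must visit. With causal
        # masking, rows in this m-block attend only to keys up to (m_block + 1) * m_block_size - 1,
        # so later key blocks are skipped entirely; e.g. (illustrative numbers) with
        # m_block_size=128 and n_block_size=64, the block handling query rows [0, 128) only
        # needs the first 2 key blocks.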
        n_block_max = cute.ceil_div(mK.shape[1], self._n_block_size)
        if self._is_causal:
            n_block_max = min(
                cute.ceil_div(
                    (m_block + 1) * self._m_block_size,
                    self._n_block_size,
                ),
                n_block_max,
            )
        n_block = n_block_max - 1

        # ///////////////////////////////////////////////////////////////////////////////
        # Get the appropriate tiles for this thread block.
        # ///////////////////////////////////////////////////////////////////////////////
        # (m_block_size, head_dim)
        gQ = cute.local_tile(
            mQ[batch_size, None, num_head, None],
            (self._m_block_size, self._head_dim_padded),
            (m_block, 0),
        )
        # (n_block_size, head_dim, n_block)
        gK = cute.local_tile(
            mK[batch_size, None, num_head, None],
            (self._n_block_size, self._head_dim_padded),
            (None, 0),
        )
        # (n_block_size, head_dim, n_block)
        gV = cute.local_tile(
            mV[batch_size, None, num_head, None],
            (self._n_block_size, self._head_dim_padded),
            (None, 0),
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Get shared memory buffer
        # ///////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()

        storage = smem.allocate(SharedStorage)
        sQ = storage.sQ.get_tensor(sQ_layout)
        sK = storage.sK.get_tensor(sKV_layout)
        sV = storage.sV.get_tensor(sKV_layout)

        # Transpose view of V to a tensor with layout (head_dim, n_block_size) for the tiled mma
        sVt = cute.composition(
            sV,
            cute.make_layout(
                (self._head_dim_padded, self._n_block_size),
                stride=(self._n_block_size, 1),
            ),
        )

        gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_slice(tidx)
        # (CPY_Atom, CPY_M, CPY_K)
        tQgQ = gmem_thr_copy_QKV.partition_S(gQ)
        tQsQ = gmem_thr_copy_QKV.partition_D(sQ)
        # (CPY_Atom, CPY_N, CPY_K, n_block)
        tKgK = gmem_thr_copy_QKV.partition_S(gK)
        tKsK = gmem_thr_copy_QKV.partition_D(sK)
        # (CPY_Atom, CPY_N, CPY_K, n_block)
        tVgV = gmem_thr_copy_QKV.partition_S(gV)
        tVsV = gmem_thr_copy_QKV.partition_D(sV)

        # ///////////////////////////////////////////////////////////////////////////////
        # Tile MMA compute thread partitions and allocate accumulators
        # ///////////////////////////////////////////////////////////////////////////////
        thr_mma = tiled_mma.get_slice(tidx)
        tSrQ = thr_mma.make_fragment_A(thr_mma.partition_A(sQ))
        tSrK = thr_mma.make_fragment_B(thr_mma.partition_B(sK))
        tOrVt = thr_mma.make_fragment_B(thr_mma.partition_B(sVt))
        acc_shape_O = thr_mma.partition_shape_C(
            (self._m_block_size, self._head_dim_padded)
        )
        acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32)
        acc_O.fill(0.0)

        # ///////////////////////////////////////////////////////////////////////////////
        # Smem copy atom tiling
        # ///////////////////////////////////////////////////////////////////////////////
        smem_copy_atom_Q = cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4),
            self._dtype,
        )
        smem_copy_atom_K = cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4),
            self._dtype,
        )
        smem_copy_atom_V = cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
            self._dtype,
        )
        smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma)
        smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma)
        smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma)

        smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx)
        smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx)
        smem_thr_copy_V = smem_tiled_copy_V.get_slice(tidx)

        tSsQ = smem_thr_copy_Q.partition_S(sQ)
        tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ)
        tSsK = smem_thr_copy_K.partition_S(sK)
        tSrK_copy_view = smem_thr_copy_K.retile(tSrK)
        tOsVt = smem_thr_copy_V.partition_S(sVt)
        tOrVt_copy_view = smem_thr_copy_V.retile(tOrVt)

        # ///////////////////////////////////////////////////////////////////////////////
        # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
        # of tile_shape
        # ///////////////////////////////////////////////////////////////////////////////
        # Construct identity layouts for Q and KV
        mcQ = cute.make_identity_tensor(mQ.layout.shape)
        mcKV = cute.make_identity_tensor(mK.layout.shape)
        cQ = cute.local_tile(
            mcQ[batch_size, None, num_head, None],
            (self._m_block_size, self._head_dim_padded),
            (m_block, 0),
        )
        cKV = cute.local_tile(
            mcKV[batch_size, None, num_head, None],
            (self._n_block_size, self._head_dim_padded),
            (n_block, 0),
        )

        # Repeat the partitioning with identity layouts
        tQcQ = gmem_thr_copy_QKV.partition_S(cQ)
        tKVcKV = gmem_thr_copy_QKV.partition_S(cKV)
        # Allocate predicate tensors for the k (head_dim) mode only; the m/n bounds are handled
        # separately. This reduces register pressure and gives a 2-3% performance gain compared
        # with allocating predicates for the whole tile.
        tQpQ = cute.make_fragment(
            cute.make_layout(
                (
                    tQsQ.shape[0][1],
                    cute.size(tQsQ, mode=[1]),
                    cute.size(tQsQ, mode=[2]),
                ),
                stride=(cute.size(tQsQ, mode=[2]), 0, 1),
            ),
            cutlass.Boolean,
        )
        tKVpKV = cute.make_fragment(
            cute.make_layout(
                (
                    tKsK.shape[0][1],
                    cute.size(tKsK, mode=[1]),
                    cute.size(tKsK, mode=[2]),
                ),
                stride=(cute.size(tKsK, mode=[2]), 0, 1),
            ),
            cutlass.Boolean,
        )
        # Set predicates for the head_dim bounds; the seqlen_q/k bounds are handled at the first tile.
        for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
            for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
                tQpQ[rest_v, 0, rest_k] = cute.elem_less(
                    tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
                )
        for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
            for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
                tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
                    tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
                )
        # ///////////////////////////////////////////////////////////////////////////////
        # Prefetch Prologue
        # ///////////////////////////////////////////////////////////////////////////////
        # Start async loads of the last mn-tile, where we take care of the mn residue
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]):
                cute.copy(
                    gmem_tiled_copy_QKV,
                    tQgQ[None, m, None],
                    tQsQ[None, m, None],
                    pred=tQpQ[None, m, None],
                )
            else:
                # Clear the smem tiles to account for predicated-off loads
                tQsQ[None, m, None].fill(0)
        for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
            if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]):
                cute.copy(
                    gmem_tiled_copy_QKV,
                    tKgK[None, n, None, n_block],
                    tKsK[None, n, None],
                    pred=tKVpKV[None, n, None],
                )
            else:
                # Clear the smem tiles to account for predicated-off loads
                tKsK[None, n, None].fill(0)

        cute.arch.cp_async_commit_group()
        # ///////////////////////////////////////////////////////////////////////////////
        # Softmax intermediate results: row_max and row_sum
        # ///////////////////////////////////////////////////////////////////////////////
        # shape: (atom_v_m * rest_m)
        row_max = cute.make_fragment(
            (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
        )
        # shape: (atom_v_m * rest_m)
        row_sum = cute.make_fragment(
            (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
        )
        row_max.fill(-cutlass.Float32.inf)
        row_sum.fill(0.0)

        # Group parameters for compute_one_n_block
        basic_params = SimpleNamespace(
            m_block=m_block,
            n_block=n_block,
            mQ=mQ,
            mK=mK,
            batch_size=batch_size,
            num_head=num_head,
        )
        mma_params = SimpleNamespace(
            thr_mma=thr_mma,
            tiled_mma=tiled_mma,
            tSrQ=tSrQ,
            tSrK=tSrK,
            tOrVt=tOrVt,
            acc_O=acc_O,
        )
        gmem_copy_params = SimpleNamespace(
            gmem_tiled_copy_QKV=gmem_tiled_copy_QKV,
            tKVcKV=tKVcKV,
            tKgK=tKgK,
            tKsK=tKsK,
            tVgV=tVgV,
            tVsV=tVsV,
            tKVpKV=tKVpKV,
        )
        smem_copy_params = SimpleNamespace(
            smem_tiled_copy_Q=smem_tiled_copy_Q,
            smem_tiled_copy_K=smem_tiled_copy_K,
            smem_tiled_copy_V=smem_tiled_copy_V,
            tSsQ=tSsQ,
            tSrQ_copy_view=tSrQ_copy_view,
            tSsK=tSsK,
            tSrK_copy_view=tSrK_copy_view,
            tOsVt=tOsVt,
            tOrVt_copy_view=tOrVt_copy_view,
        )
        softmax_params = SimpleNamespace(
            row_max=row_max,
            row_sum=row_sum,
            softmax_scale_log2=softmax_scale_log2,
        )

        # Start processing with the first n-block (blocks are visited in reverse order).
        # For performance reasons, we separate out two kinds of iterations:
        # those that need masking on S, and those that don't.
        # We need masking on S for the very last block when the K/V length is not a multiple of n_block_size.
        # We also need masking on S in the causal case, for the last ceil_div(m_block_size, n_block_size) blocks.
        # There is always at least one "masking" iteration.
        mask_steps = 1
        if cutlass.const_expr(self._is_causal):
            mask_steps = cute.ceil_div(self._m_block_size, self._n_block_size)

        for n_tile in cutlass.range_constexpr(mask_steps):
            n_block = n_block_max - n_tile - 1
            basic_params.n_block = n_block
            if cutlass.const_expr(self._is_causal):
                if n_block >= 0:
                    self.compute_one_n_block(
                        basic_params,
                        mma_params,
                        gmem_copy_params,
                        smem_copy_params,
                        softmax_params,
                        is_first_n_block=(n_tile == 0),
                        in_mask_steps=True,
                    )
            else:
                self.compute_one_n_block(
                    basic_params,
                    mma_params,
                    gmem_copy_params,
                    smem_copy_params,
                    softmax_params,
                    is_first_n_block=True,
                    in_mask_steps=True,
                )

        # Process the remaining n-blocks in reverse order; no n-residue handling is needed there.
        for n_tile in range(mask_steps, n_block_max, 1):
            n_block = n_block_max - n_tile - 1
            basic_params.n_block = n_block
            self.compute_one_n_block(
                basic_params,
                mma_params,
                gmem_copy_params,
                smem_copy_params,
                softmax_params,
                is_first_n_block=False,
                in_mask_steps=False,
            )

        # ///////////////////////////////////////////////////////////////////////////////
        # Epilogue
        # ///////////////////////////////////////////////////////////////////////////////
        # Normalize acc_O by row_sum
        self.normalize_softmax(acc_O, row_sum)
        # store acc_O
        rO = cute.make_fragment_like(acc_O, self._dtype)
        rO.store(acc_O.load().to(self._dtype))
        # Reuse sQ's smem buffer for the O tile
        sO = cute.make_tensor(sQ.iterator, sO_layout)

        # smem copy atom for O
        smem_copy_atom_O = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), self._dtype
        )
        # tiled smem copy for O
        smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma)
        smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx)
        taccOrO = smem_thr_copy_O.retile(rO)
        taccOsO = smem_thr_copy_O.partition_D(sO)
        # copy acc O from rmem to smem with the smem copy atom
        cute.copy(
            smem_copy_atom_O,
            taccOrO,
            taccOsO,
        )
        gO = cute.local_tile(
            mO[batch_size, None, num_head, None],
            (self._m_block_size, self._head_dim_padded),
            (m_block, 0),
        )

        gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
        tOsO = gmem_thr_copy_O.partition_S(sO)
        tOgO = gmem_thr_copy_O.partition_D(gO)
        tOrO = cute.make_fragment_like(tOgO, self._dtype)
        # Synchronize to make sure all smem stores of O have completed before reading them back.
        cute.arch.barrier()
        # load acc O from smem to rmem for wider vectorization
        cute.copy(
            gmem_tiled_copy_O,
            tOsO,
            tOrO,
        )
        mcO = cute.make_identity_tensor(mO.layout.shape)
        cO = cute.local_tile(
            mcO[batch_size, None, num_head, None],
            (self._m_block_size, self._head_dim_padded),
            (m_block, 0),
        )
        tOcO = gmem_thr_copy_O.partition_D(cO)
        tOpO = cute.make_fragment(
            cute.make_layout(
                (tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]),
                stride=(tOgO.shape[2], 0, 1),
            ),
            cutlass.Boolean,
        )
        for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
            for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
                tOpO[rest_v, 0, rest_n] = cute.elem_less(
                    tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
                )
        # copy acc O from rmem to gmem
        for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])):
            if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]):
                cute.copy(
                    gmem_tiled_copy_O,
                    tOrO[None, rest_m, None],
                    tOgO[None, rest_m, None],
                    pred=tOpO[None, rest_m, None],
                )

    @cute.jit
    def compute_one_n_block(
        self,
        basic_params: SimpleNamespace,
        mma_params: SimpleNamespace,
        gmem_copy_params: SimpleNamespace,
        smem_copy_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        is_first_n_block: cutlass.Constexpr,
        in_mask_steps: cutlass.Constexpr,
    ):
        """Compute one n_block of S/O.

        This function provides different variants for processing the first n block versus subsequent blocks,
        as well as variants for handling masked and unmasked steps.

        :param basic_params: basic parameters
        :type basic_params: SimpleNamespace
        :param mma_params: mma parameters
        :type mma_params: SimpleNamespace
        :param gmem_copy_params: gmem copy parameters
        :type gmem_copy_params: SimpleNamespace
        :param smem_copy_params: smem copy parameters
        :type smem_copy_params: SimpleNamespace
        :param softmax_params: softmax parameters
        :type softmax_params: SimpleNamespace
        :param is_first_n_block: whether this is the first n block
        :type is_first_n_block: cutlass.Constexpr
        :param in_mask_steps: whether this block is one of the masking steps
        :type in_mask_steps: cutlass.Constexpr
        """
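        # Per-block schedule (overlapping global loads, shared memory and register traffic):
        #   1. wait for the K tile (and Q on the first block) in smem, then issue the async load of this block's V tile
        #   2. compute S = Q @ K^T, double-buffering the smem->register copies over the k dimension
        #   3. wait for V, then issue the async load of the next block's K tile
        #   4. apply the online softmax update and rescale acc_O
        #   5. compute acc_O += P @ V with the same register pipelining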
        acc_shape_S = mma_params.thr_mma.partition_shape_C(
            (self._m_block_size, self._n_block_size)
        )
        acc_S = cute.make_fragment(acc_shape_S, cutlass.Float32)
        acc_S.fill(0.0)

        # Wait for the Q/K smem tiles before the mma calculation for S
        cute.arch.cp_async_wait_group(0)
        cute.arch.barrier()
        # Load the V smem tile for O; the first tile is handled specially (with predication)
        # to avoid loading NaN values. `is_first_n_block` is a constexpr, so this `if` is
        # resolved at compile time and is not generated in the IR.
        if is_first_n_block:
            for n in cutlass.range_constexpr(cute.size(gmem_copy_params.tVsV.shape[1])):
                if cute.elem_less(
                    gmem_copy_params.tKVcKV[0, n, 0][1],
                    basic_params.mK.layout.shape[1],
                ):
                    cute.copy(
                        gmem_copy_params.gmem_tiled_copy_QKV,
                        gmem_copy_params.tVgV[None, n, None, basic_params.n_block],
                        gmem_copy_params.tVsV[None, n, None],
                        pred=gmem_copy_params.tKVpKV[None, n, None],
                    )
                else:
                    gmem_copy_params.tVsV[None, n, None].fill(0.0)
        else:
            cute.copy(
                gmem_copy_params.gmem_tiled_copy_QKV,
                gmem_copy_params.tVgV[None, None, None, basic_params.n_block],
                gmem_copy_params.tVsV,
                pred=gmem_copy_params.tKVpKV,
            )

        cute.arch.cp_async_commit_group()
        # ///////////////////////////////////////////////////////////////////////////////
        # S gemm calculation
        # ///////////////////////////////////////////////////////////////////////////////
        # Load the first QK k-block from smem to rmem for the mma
        cute.copy(
            smem_copy_params.smem_tiled_copy_Q,
            smem_copy_params.tSsQ[None, None, 0],
            smem_copy_params.tSrQ_copy_view[None, None, 0],
        )
        cute.copy(
            smem_copy_params.smem_tiled_copy_K,
            smem_copy_params.tSsK[None, None, 0],
            smem_copy_params.tSrK_copy_view[None, None, 0],
        )
        # mma for S
        for k in cutlass.range_constexpr(cute.size(smem_copy_params.tSsQ.shape[2])):
            # Load the next QK k-block from smem to rmem for the mma
            k_next = (k + 1) % cute.size(smem_copy_params.tSsQ.shape[2])
            cute.copy(
                smem_copy_params.smem_tiled_copy_Q,
                smem_copy_params.tSsQ[None, None, k_next],
                smem_copy_params.tSrQ_copy_view[None, None, k_next],
            )
            cute.copy(
                smem_copy_params.smem_tiled_copy_K,
                smem_copy_params.tSsK[None, None, k_next],
                smem_copy_params.tSrK_copy_view[None, None, k_next],
            )
            cute.gemm(
                mma_params.tiled_mma,
                acc_S,
                mma_params.tSrQ[None, None, k],
                mma_params.tSrK[None, None, k],
                acc_S,
            )

        # Wait for the V smem tile needed for O
        cute.arch.cp_async_wait_group(0)
        cute.arch.barrier()

        if basic_params.n_block > 0:
            cute.copy(
                gmem_copy_params.gmem_tiled_copy_QKV,
                gmem_copy_params.tKgK[None, None, None, basic_params.n_block - 1],
                gmem_copy_params.tKsK,
                pred=gmem_copy_params.tKVpKV,
            )
        cute.arch.cp_async_commit_group()
        # ///////////////////////////////////////////////////////////////////////////////
        # online softmax
        # ///////////////////////////////////////////////////////////////////////////////
        self.softmax_rescale_O(
            basic_params,
            mma_params,
            softmax_params,
            acc_S,
            is_first_n_block,
            in_mask_steps,
        )

        rP = cute.make_fragment_like(acc_S, self._dtype)
        rP.store(acc_S.load().to(self._dtype))
        # ///////////////////////////////////////////////////////////////////////////////
        # O gemm calculation
        # ///////////////////////////////////////////////////////////////////////////////
        # Convert the layout of acc_S to the layout the O gemm accepts as its A operand.
        # Because the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N)
        # to ((4, 2), MMA_M, MMA_N / 2):
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) -> ((4, 2), MMA_M, MMA_N / 2)
        rP_layout_divided = cute.logical_divide(rP.layout, (None, None, 2))
        rP_mma_view = cute.make_layout(
            (
                (rP_layout_divided.shape[0], rP_layout_divided.shape[2][0]),
                rP_layout_divided.shape[1],
                rP_layout_divided.shape[2][1],
            ),
            stride=(
                (rP_layout_divided.stride[0], rP_layout_divided.stride[2][0]),
                rP_layout_divided.stride[1],
                rP_layout_divided.stride[2][1],
            ),
        )
        tOrS = cute.make_tensor(rP.iterator, rP_mma_view)

        # Load the first V k-block from smem to rmem for the mma
        cute.copy(
            smem_copy_params.smem_tiled_copy_V,
            smem_copy_params.tOsVt[None, None, 0],
            smem_copy_params.tOrVt_copy_view[None, None, 0],
        )
        # mma for O
        for k in cutlass.range_constexpr(cute.size(tOrS.shape[2])):
            # Load the next V k-block from smem to rmem for the mma
            k_next = (k + 1) % cute.size(tOrS.shape[2])
            cute.copy(
                smem_copy_params.smem_tiled_copy_V,
                smem_copy_params.tOsVt[None, None, k_next],
                smem_copy_params.tOrVt_copy_view[None, None, k_next],
            )
            cute.gemm(
                mma_params.tiled_mma,
                mma_params.acc_O,
                tOrS[None, None, k],
                mma_params.tOrVt[None, None, k],
                mma_params.acc_O,
            )

    @cute.jit
    def softmax_rescale_O(
        self,
        basic_params: SimpleNamespace,
        mma_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        acc_S: cute.Tensor,
        is_first_n_block: cutlass.Constexpr,
        in_mask_steps: cutlass.Constexpr,
    ):
        """Apply online softmax and rescale acc_O.

        This function provides different variants for processing the first n block versus subsequent blocks,
        as well as variants for handling masked and unmasked steps.

        :param basic_params: basic parameters
        :type basic_params: SimpleNamespace
        :param mma_params: mma parameters
        :type mma_params: SimpleNamespace
        :param softmax_params: softmax parameters
        :type softmax_params: SimpleNamespace
        :param acc_S: acc_S tensor
        :type acc_S: cute.Tensor
        :param is_first_n_block: whether this is the first n_block
        :type is_first_n_block: cutlass.Constexpr
        :param in_mask_steps: whether this block is one of the masking steps
        :type in_mask_steps: cutlass.Constexpr
        """
        # Change acc_S to an M,N layout view.
        acc_S_mn = self._make_acc_tensor_mn_view(acc_S)
        acc_O_mn = self._make_acc_tensor_mn_view(mma_params.acc_O)
        row_max_prev = None
        # If this is not the first tile, keep a copy of the previous row_max so it can be
        # compared with the current per-row maximum.
        if cutlass.const_expr(not is_first_n_block):
            row_max_prev = cute.make_fragment_like(
                softmax_params.row_max, cutlass.Float32
            )
            cute.basic_copy(softmax_params.row_max, row_max_prev)
        # In masking steps, build a coordinate tensor so that out-of-bounds / masked elements
        # of S can be set to -inf before the softmax.
        tScS_mn = None
        if cutlass.const_expr(in_mask_steps):
            mcS = cute.make_identity_tensor(
                (
                    basic_params.mQ.shape[0],
                    basic_params.mQ.shape[1],
                    basic_params.mQ.shape[2],
                    basic_params.mK.shape[1],
                )
            )
            cS = cute.local_tile(
                mcS[basic_params.batch_size, None, basic_params.num_head, None],
                (self._m_block_size, self._n_block_size),
                (basic_params.m_block, basic_params.n_block),
            )
            tScS = mma_params.thr_mma.partition_C(cS)
            tScS_mn = self._make_acc_tensor_mn_view(tScS)
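
        # Online softmax update per row r, working in log2 space
        # (scale_log2 = softmax_scale * log2(e)):
        #   m_new = max(m_prev, rowmax(S_r))
        #   p_r   = exp2(S_r * scale_log2 - m_new * scale_log2)
        #   l_new = exp2((m_prev - m_new) * scale_log2) * l_prev + rowsum(p_r)
        #   O_r   = O_r * exp2((m_prev - m_new) * scale_log2)
        # where m is row_max and l is row_sum; on the first block m_prev = -inf and l_prev = 0.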
        # Each iteration processes one row of acc_S
        for r in cutlass.range_constexpr(cute.size(softmax_params.row_max)):
            # Mask the residual of S with -inf
            if cutlass.const_expr(in_mask_steps):
                if cutlass.const_expr(not self._is_causal):
                    # traverse the column index.
                    for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
                        if cute.elem_less(
                            basic_params.mK.shape[1], tScS_mn[0, c][3] + 1
                        ):
                            acc_S_mn[r, c] = -cutlass.Float32.inf
                else:
                    # Get the column index limit for the current row (causal mask). Only the
                    # row index matters here, so the column index is set to 0.
                    col_idx_limit = cutlass.min(
                        tScS_mn[r, 0][1] + 1, basic_params.mK.shape[1]
                    )
                    # traverse the column index.
                    for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
                        # Only the column index matters here, so the row index is set to 0.
                        if cute.elem_less(col_idx_limit, tScS_mn[0, c][3] + 1):
                            acc_S_mn[r, c] = -cutlass.Float32.inf

            # (n_block_size)
            acc_S_row = acc_S_mn[r, None].load()
            # row_max_cur_row => f32
            row_max_cur_row = acc_S_row.reduce(
                cute.ReductionOp.MAX, -cutlass.Float32.inf, 0
            )
            # quad reduction for row_max
            row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row)
            row_max_prev_row = None
            # If this is not the first tile, compare the previous row_max for row r with the current one.
            if cutlass.const_expr(not is_first_n_block):
                row_max_prev_row = row_max_prev[r]
                row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row)
            if cutlass.const_expr(self._is_causal):
                row_max_cur_row = (
                    0.0 if row_max_cur_row == -cutlass.Float32.inf else row_max_cur_row
                )

            # compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e))
            acc_S_row_exp = cute.TensorSSA(
                self._exp2f(
                    acc_S_row * softmax_params.softmax_scale_log2
                    - row_max_cur_row * softmax_params.softmax_scale_log2
                ),
                tuple(acc_S_row.shape),
                cutlass.Float32,
            )
            # acc_S_row_sum => f32
            acc_S_row_sum = acc_S_row_exp.reduce(
                cute.ReductionOp.ADD, cutlass.Float32.zero, 0
            )
            # If this is not the first tile, rescale the previous row_sum and acc_O by
            # exp(row_max_prev - row_max_cur).
            if cutlass.const_expr(not is_first_n_block):
                prev_minus_cur_exp = self._exp2f(
                    row_max_prev_row * softmax_params.softmax_scale_log2
                    - row_max_cur_row * softmax_params.softmax_scale_log2
                )
                acc_S_row_sum = (
                    acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp
                )
                acc_O_mn[r, None] = acc_O_mn[r, None].load() * prev_minus_cur_exp
            # update row_max, row_sum and acc_S
            softmax_params.row_max[r] = row_max_cur_row
            softmax_params.row_sum[r] = acc_S_row_sum
            acc_S_mn[r, None] = acc_S_row_exp

    @cute.jit
    def normalize_softmax(
        self,
        acc_O: cute.Tensor,
        row_sum: cute.Tensor,
    ):
        """Normalize acc_O by row_sum.

        :param acc_O: input tensor
        :type acc_O: cute.Tensor
        :param row_sum: row_sum tensor
        :type row_sum: cute.Tensor
        """
        # Do a quad reduction for row_sum.
        acc_O_mn = self._make_acc_tensor_mn_view(acc_O)
        for r in cutlass.range_constexpr(cute.size(row_sum)):
            row_sum[r] = self._threadquad_reduce_sum(row_sum[r])
            # If row_sum is zero or NaN, use a scale of 1.0 to avoid producing NaN in the output.
            acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]

            scale = (
                1.0 if acc_O_mn_row_is_zero_or_nan else cute.arch.rcp_approx(row_sum[r])
            )

            acc_O_mn[r, None] = acc_O_mn[r, None].load() * scale

    def _make_acc_tensor_mn_view(self, acc: cute.Tensor) -> cute.Tensor:
        """Make an (M, N) layout view of an mma accumulator tensor.

        :param acc: input tensor
        :type acc: cute.Tensor
        :return: acc tensor in an (M, N) layout view
        :rtype: cute.Tensor
        """
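        # Per the SM80 16x8x16 mma accumulator layout, each thread's fragment has shape
        # ((2, 2), MMA_M, MMA_N): the first inner mode walks along N (two adjacent columns)
        # and the second along M (two rows of the C tile), so the M view below is built from
        # (shape[0][1], shape[1]) and the N view from (shape[0][0], shape[2]).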
        acc_layout_col_major = cute.make_layout(acc.layout.shape)
        acc_layout_mn = cute.make_layout(
            (
                (
                    acc_layout_col_major.shape[0][1],
                    acc_layout_col_major.shape[1],
                ),  # MMA_M
                (
                    acc_layout_col_major.shape[0][0],
                    acc_layout_col_major.shape[2],
                ),  # MMA_N
            ),
            stride=(
                (
                    acc_layout_col_major.stride[0][1],
                    acc_layout_col_major.stride[1],
                ),  # MMA_M
                (
                    acc_layout_col_major.stride[0][0],
                    acc_layout_col_major.stride[2],
                ),  # MMA_N
            ),
        )
        acc_layout_mn = cute.composition(acc.layout, acc_layout_mn)
        return cute.make_tensor(acc.iterator, acc_layout_mn)

    def _threadquad_reduce(self, val: cutlass.Float32, op: Callable) -> cutlass.Float32:
        """Thread quad reduction.

        :param val: register value
        :type val: cutlass.Float32
        :param op: binary operator
        :type op: Callable
        :return: reduced value
        :rtype: cutlass.Float32
        """
        val = op(
            val,
            cute.arch.shuffle_sync_bfly(val, offset=2, mask=-1, mask_and_clamp=31),
        )
        val = op(
            val,
            cute.arch.shuffle_sync_bfly(val, offset=1, mask=-1, mask_and_clamp=31),
        )
        return val

    def _threadquad_reduce_max(self, val: cutlass.Float32) -> cutlass.Float32:
        """Thread quad reduction max.

        :param val: register value
        :type val: cutlass.Float32
        :return: max value
        :rtype: cutlass.Float32
        """
        return self._threadquad_reduce(val, lambda x, y: cute.arch.fmax(x, y))

    def _threadquad_reduce_sum(self, val: cutlass.Float32) -> cutlass.Float32:
        """Thread quad reduction sum.

        :param val: register value
        :type val: cutlass.Float32
        :return: sum value
        :rtype: cutlass.Float32
        """
        return self._threadquad_reduce(val, lambda x, y: x + y)

    def _exp2f(
        self, x: Union[cute.TensorSSA, cutlass.Float32]
    ) -> Union[cute.TensorSSA, cutlass.Float32]:
        """exp2f calculation for both vector and scalar.

        :param x: input value
        :type x: cute.TensorSSA or cutlass.Float32
        :return: exp2 value
        :rtype: cute.TensorSSA or cutlass.Float32
        """
        if isinstance(x, cute.TensorSSA):
            res = cute.make_fragment(x.shape, cutlass.Float32)
            res.store(x)

            for i in range(cute.size(x.shape)):
                res[i] = self._exp2f(res[i])

            return res.load()
        return cute.arch.exp2(x)


def run(
    dtype: Type[cutlass.Numeric],
    batch_size: int,
    seqlen_q: int,
    seqlen_k: int,
    num_head: int,
    head_dim: int,
    softmax_scale: float = 1.0,
    m_block_size: int = 128,
    n_block_size: int = 128,
    num_threads: int = 128,
    is_causal: bool = False,
    warmup_iterations: int = 0,
    iterations: int = 1,
    skip_ref_check: bool = False,
    use_cold_l2: bool = False,
    **kwargs,
):
    # Skip unsupported test cases
    if not FlashAttentionForwardAmpere.can_implement(
        dtype,
        head_dim,
        m_block_size,
        n_block_size,
        num_threads,
        is_causal,
    ):
        raise TypeError(
            f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}"
        )

    print("Running Ampere SM80 FlashAttentionForward test with:")
    print(f"  dtype: {dtype}")
    print(f"  batch_size: {batch_size}")
    print(f"  seqlen_q: {seqlen_q}")
    print(f"  seqlen_k: {seqlen_k}")
    print(f"  num_head: {num_head}")
    print(f"  head_dim: {head_dim}")
    print(f"  softmax_scale: {softmax_scale}")
    print(f"  m_block_size: {m_block_size}")
    print(f"  n_block_size: {n_block_size}")
    print(f"  num_threads: {num_threads}")
    print(f"  is_causal: {is_causal}")
    print(f"  warmup_iterations: {warmup_iterations}")
    print(f"  iterations: {iterations}")
    print(f"  skip_ref_check: {skip_ref_check}")
    print(f"  use_cold_l2: {use_cold_l2}")

    # Create the Q/K/V/O tensors
    def create_tensor(
        batch_size: int,
        seqlen: int,
        num_head: int,
        head_dim: int,
        dtype: Type[cutlass.Numeric],
    ):
        # (batch_size, seqlen, num_head, head_dim)
        shape = (batch_size, seqlen, num_head, head_dim)
        torch_tensor = (
            torch.empty(*shape, dtype=torch.int32)
            .random_(-2, 2)
            .to(dtype=cutlass_torch.dtype(dtype))
            .cuda()
        )
        # Assume the input is 16B aligned.
        cute_tensor = (
            from_dlpack(torch_tensor, assumed_align=16)
            .mark_layout_dynamic(leading_dim=3)
            .mark_compact_shape_dynamic(
                mode=3,
                stride_order=torch_tensor.dim_order(),
                divisibility=(128 // dtype.width),
            )
        )
        return cute_tensor, torch_tensor

    q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
    k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
    v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
    o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)

    fa2_fwd = FlashAttentionForwardAmpere(
        head_dim,
        m_block_size,
        n_block_size,
        num_threads,
        is_causal,
    )

    # Get the current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)
    # Compile the fa2 forward pass
    compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream)

    if not skip_ref_check:
        compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream)
        torch.cuda.synchronize()
        q_ref = q_torch.permute(0, 2, 1, 3)
        k_ref = k_torch.permute(0, 2, 1, 3)
        v_ref = v_torch.permute(0, 2, 1, 3)
        torch.backends.cuda.enable_flash_sdp(enabled=True)
        ref_o = torch.nn.functional.scaled_dot_product_attention(
            q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
        ).permute(0, 2, 1, 3)
        torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
        print("Results verified successfully!")

    def generate_tensors():
        q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
        k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
        v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
        o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
        return testing.JitArguments(
            q_workspace,
            k_workspace,
            v_workspace,
            o_workspace,
            softmax_scale,
            current_stream,
        )

    workspace_count = 1
    if use_cold_l2:
        one_workspace_bytes = (
            q_torch.numel() * q_torch.element_size()
            + k_torch.numel() * k_torch.element_size()
            + v_torch.numel() * v_torch.element_size()
            + o_torch.numel() * o_torch.element_size()
        )
        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )

    avg_time_us = testing.benchmark(
        compiled_fa2_fwd,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        stream=current_stream,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    return avg_time_us  # Return execution time in microseconds


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="example of flash attention v2 with CuTe on GPU"
    )
    parser.add_argument("--dtype", type=cutlass.dtype, default=cutlass.BFloat16)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--seqlen_q", type=int, default=8192)
    parser.add_argument("--seqlen_k", type=int, default=8192)
    parser.add_argument("--num_head", type=int, default=16)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--softmax_scale", type=float, default=0.5)
    parser.add_argument("--m_block_size", type=int, default=128)
    parser.add_argument("--n_block_size", type=int, default=64)
    parser.add_argument("--num_threads", type=int, default=128)
    parser.add_argument("--is_causal", action="store_true", help="Enable causal mask")
    parser.add_argument("--warmup_iterations", type=int, default=3)
    parser.add_argument("--iterations", type=int, default=10)
    parser.add_argument(
        "--skip_ref_check", action="store_true", help="Skip reference check"
    )
    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        default=False,
        help="Use circular buffer tensor sets to ensure L2 cold cache",
    )

    args = parser.parse_args()
    run(
        args.dtype,
        args.batch_size,
        args.seqlen_q,
        args.seqlen_k,
        args.num_head,
        args.head_dim,
        args.softmax_scale,
        args.m_block_size,
        args.n_block_size,
        args.num_threads,
        args.is_causal,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )

    print("PASS")