# cutlass/examples/python/CuTeDSL/ampere/flash_attention_v2.py
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from types import SimpleNamespace
from typing import Type, Union, Callable
import torch
import cuda.bindings.driver as cuda
import cutlass.cute.testing as testing
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, warp
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
import cutlass.utils as utils
"""
A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL.
- Matrix Q is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension
- Matrix K is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension
- Matrix V is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension
- Matrix O is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension
This kernel supports the following features:
- Utilizes CpAsync for efficient memory operations
- Utilizes Ampere's tensor core for matrix multiply-accumulate (MMA) operations
- Utilizes register pipeline to overlap shared memory-to-register transfers with computations.
- Leverages DSL to implement an integrated online softmax fusion pattern.
This kernel works as follows:
1. Load Q and K matrices from global memory (GMEM) to shared memory (SMEM) using CpAsync operations.
2. Perform matrix multiply-accumulate (MMA) operations using tensor core instructions to compute intermediate result S.
3. Apply padding mask or causal mask to S during initial iterations.
4. Apply online softmax to S and rescale O using results from previous iteration.
5. Load V matrices and perform matrix multiply-accumulate (MMA) operations to compute final result O.
6. Normalize O after all iterations complete and store result back to global memory (GMEM).
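For reference, the unfused computation this kernel implements is equivalent to the following
PyTorch sketch (illustrative only; `attention_reference` is a hypothetical helper, not part of
this example):

.. code-block:: python

    import torch

    def attention_reference(q, k, v, scale, is_causal):
        # q: (B, Sq, N, H); k, v: (B, Sk, N, H). Move heads next to batch for the matmuls.
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
        s = scale * (q @ k.transpose(-2, -1))             # (B, N, Sq, Sk)
        if is_causal:
            sq, sk = s.shape[-2], s.shape[-1]
            mask = torch.tril(torch.ones(sq, sk, device=s.device)).bool()
            s = s.masked_fill(~mask, float("-inf"))
        p = torch.softmax(s.float(), dim=-1).to(v.dtype)  # row-wise softmax over keys
        return (p @ v).permute(0, 2, 1, 3)                # back to (B, Sq, N, H)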
To run this example:
.. code-block:: bash

    python examples/ampere/flash_attention_v2.py \
        --dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \
        --num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \
        --num_head 16 --softmax_scale 1.0 --is_causal
The above command configures the kernel to use float16 for inputs and outputs. The problem dimensions
are set to: batch size of 1, query sequence length of 1280, key sequence length of 1536, head dimension
of 128, and 16 attention heads. The softmax scale is set to 1.0 and causal masking is enabled. The computation
uses tiles of size 128x128 for m and n dimensions, and utilizes 128 parallel threads.
To collect the performance with NCU profiler:
.. code-block:: bash

    ncu python examples/ampere/flash_attention_v2.py \
        --dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \
        --num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \
        --num_head 16 --softmax_scale 1.0 --is_causal --skip_ref_check
There are some constraints for this example:
* Only fp16 and bf16 data types are supported.
* The contiguous dimension of each tensor must be 16-byte aligned.
* The log-sum-exp (needed for training) is not computed in the kernel.
* The values of `m_block_size`, `n_block_size`, and `head_dim` must be chosen to stay within the shared memory capacity limit.
* `m_block_size * 2` must be divisible by `num_threads`; otherwise the kernel produces incorrect results.
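These constraints can be checked ahead of launch with `can_implement`; a minimal sketch
(assuming `FlashAttentionForwardAmpere` from this file is in scope):

.. code-block:: python

    import cutlass

    ok = FlashAttentionForwardAmpere.can_implement(
        dtype=cutlass.Float16,
        head_dim=128,
        m_block_size=128,
        n_block_size=128,
        num_threads=128,
        is_causal=True,
    )
    assert ok, "unsupported configuration"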
"""
class FlashAttentionForwardAmpere:
def __init__(
self,
head_dim: int,
m_block_size: int = 128,
n_block_size: int = 128,
num_threads: int = 128,
is_causal: bool = False,
):
"""Initializes the configuration for a flash attention v2 kernel.
All contiguous dimensions must be at least 16-byte aligned, which means the head dimension
must be a multiple of 8 for the supported 16-bit data types.
:param head_dim: head dimension
:type head_dim: int
:param m_block_size: m block size
:type m_block_size: int
:param n_block_size: n block size
:type n_block_size: int
:param num_threads: number of threads
:type num_threads: int
:param is_causal: whether to apply a causal mask
:type is_causal: bool
"""
self._head_dim = head_dim
self._m_block_size = m_block_size
self._n_block_size = n_block_size
# Pad head_dim up to a multiple of 32; the padded value is used as the k block size
self._head_dim_padded = (head_dim + 31) // 32 * 32
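# e.g. head_dim = 40 is padded to head_dim_padded = 64, while head_dim = 128 stays at 128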
self._num_threads = num_threads
self._is_causal = is_causal
@staticmethod
def can_implement(
dtype, head_dim, m_block_size, n_block_size, num_threads, is_causal
) -> bool:
"""Check if the kernel can be implemented with the given parameters.
:param dtype: data type
:type dtype: cutlass.Numeric
:param head_dim: head dimension
:type head_dim: int
:param m_block_size: m block size
:type m_block_size: int
:param n_block_size: n block size
:type n_block_size: int
:param num_threads: number of threads
:type num_threads: int
:param is_causal: is causal
:type is_causal: bool
:return: True if the kernel can be implemented, False otherwise
:rtype: bool
"""
# Check if data type is fp16 or bf16
if dtype != cutlass.Float16 and dtype != cutlass.BFloat16:
return False
# Check if head dimension is a multiple of 8
if head_dim % 8 != 0:
return False
# Check if number of threads is a multiple of 32
if num_threads % 32 != 0:
return False
# Check whether the block size setting exceeds the shared memory capacity
# Shared memory usage in bytes: Q tile + K tile + V tile (K and V use the same tile size), 2 bytes per element
smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2
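# e.g. m_block_size = n_block_size = 128 with head_dim = 128 and a 2-byte dtype:
# (128 * 128 + 128 * 128 * 2) * 2 = 98304 bytes (96 KiB), well within the SM80 capacity queried below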
smem_capacity = utils.get_smem_capacity_in_bytes("sm_80")
if smem_usage > smem_capacity:
return False
# Check if twice the block size is divisible by the number of threads
if (m_block_size * 2) % num_threads != 0:
return False
return True
@cute.jit
def __call__(
self,
mQ: cute.Tensor,
mK: cute.Tensor,
mV: cute.Tensor,
mO: cute.Tensor,
softmax_scale: cutlass.Float32,
stream: cuda.CUstream,
):
"""Configures and launches the flash attention v2 kernel.
mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout:
(batch_size, seqlen_q, num_head, head_dim):(seqlen_q * num_head * head_dim, num_head * head_dim, head_dim, 1)
Prepares the shared memory layout, tiled copy atoms, tiled mma and shared memory storage.
Then launches the kernel function with the prepared parameters.
:param mQ: query tensor
:type mQ: cute.Tensor
:param mK: key tensor
:type mK: cute.Tensor
:param mV: value tensor
:type mV: cute.Tensor
:param mO: output tensor
:type mO: cute.Tensor
:param softmax_scale: softmax scale
:type softmax_scale: cutlass.Float32
"""
# Get the data type and check if it is fp16 or bf16
if cutlass.const_expr(
not (
mQ.element_type == mK.element_type == mV.element_type == mO.element_type
)
):
raise TypeError("All tensors must have the same data type")
if cutlass.const_expr(
not (
mQ.element_type == cutlass.Float16
or mQ.element_type == cutlass.BFloat16
)
):
raise TypeError("Only Float16 or BFloat16 is supported")
self._dtype: Type[cutlass.Numeric] = mQ.element_type
# ///////////////////////////////////////////////////////////////////////////////
# Shared memory layout: Q/K/V
# ///////////////////////////////////////////////////////////////////////////////
smem_k_block_size = 64 if self._head_dim_padded % 64 == 0 else 32
swizzle_bits = 3 if smem_k_block_size == 64 else 2
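# A 128-byte (64-element) row atom uses a 3-bit swizzle and a 64-byte (32-element) row atom a 2-bit
# swizzle, which helps avoid shared memory bank conflicts for the ldmatrix/cp.async accesses below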
sQ_layout_atom = cute.make_composed_layout(
cute.make_swizzle(swizzle_bits, 3, 3),
0,
cute.make_layout((8, smem_k_block_size), stride=(smem_k_block_size, 1)),
)
sQ_layout = cute.tile_to_shape(
sQ_layout_atom,
(self._m_block_size, self._head_dim_padded),
(0, 1),
)
sKV_layout_atom = sQ_layout_atom
sKV_layout = cute.tile_to_shape(
sKV_layout_atom,
(self._n_block_size, self._head_dim_padded),
(0, 1),
)
sO_layout = sQ_layout
@cute.struct
class SharedStorage:
sQ: cute.struct.Align[
cute.struct.MemRange[self._dtype, cute.cosize(sQ_layout)], 1024
]
sK: cute.struct.Align[
cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024
]
sV: cute.struct.Align[
cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024
]
# ///////////////////////////////////////////////////////////////////////////////
# GMEM Tiled copy:
# ///////////////////////////////////////////////////////////////////////////////
# Thread layouts for copies
universal_copy_bits = 128
async_copy_elems = universal_copy_bits // self._dtype.width
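# e.g. 128-bit copies of a 16-bit dtype move 8 elements per thread per copy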
# atom_async_copy: async copy atom for QKV load
atom_async_copy = cute.make_copy_atom(
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
self._dtype,
num_bits_per_copy=universal_copy_bits,
)
# atom_universal_copy: universal copy atom for O store
atom_universal_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self._dtype,
num_bits_per_copy=universal_copy_bits,
)
# tQKV_layout: thread layout for QKV load
tQKV_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
tQKV_layout = cute.make_layout(
(self._num_threads // tQKV_shape_dim_1, tQKV_shape_dim_1),
stride=(tQKV_shape_dim_1, 1),
)
# tO_layout: thread layout for O store
tO_layout = tQKV_layout
# Value layouts for copies
vQKV_layout = cute.make_layout((1, async_copy_elems))
vO_layout = vQKV_layout
# gmem_tiled_copy_QKV: tiled copy for QKV load
gmem_tiled_copy_QKV = cute.make_tiled_copy_tv(
atom_async_copy, tQKV_layout, vQKV_layout
)
# gmem_tiled_copy_O: tiled copy for O store
gmem_tiled_copy_O = cute.make_tiled_copy_tv(
atom_universal_copy, tO_layout, vO_layout
)
# ///////////////////////////////////////////////////////////////////////////////
# Tiled mma
# ///////////////////////////////////////////////////////////////////////////////
tiled_mma = cute.make_tiled_mma(
warp.MmaF16BF16Op(self._dtype, cutlass.Float32, (16, 8, 16)),
(self._num_threads // 32, 1, 1),
permutation_mnk=(self._num_threads // 32 * 16, 16, 16),
)
# grid_dim: (m_block, batch_size, num_head)
grid_dim = (
cute.ceil_div(mQ.shape[1], self._m_block_size),
cute.size(mQ.shape[0]),
cute.size(mQ.shape[2]),
)
LOG2_E = 1.4426950408889634074
softmax_scale_log2 = softmax_scale * LOG2_E
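# The log2(e) factor lets the softmax use the fast exp2 instruction:
# exp(scale * x) == exp2((scale * log2(e)) * x); see _exp2f and softmax_rescale_O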
self.kernel(
mQ,
mK,
mV,
mO,
softmax_scale_log2,
sQ_layout,
sKV_layout,
sO_layout,
gmem_tiled_copy_QKV,
gmem_tiled_copy_O,
tiled_mma,
SharedStorage,
).launch(
grid=grid_dim,
block=[self._num_threads, 1, 1],
smem=SharedStorage.size_in_bytes(),
stream=stream,
)
@cute.kernel
def kernel(
self,
mQ: cute.Tensor,
mK: cute.Tensor,
mV: cute.Tensor,
mO: cute.Tensor,
softmax_scale_log2: cutlass.Float32,
sQ_layout: cute.ComposedLayout,
sKV_layout: cute.ComposedLayout,
sO_layout: cute.ComposedLayout,
gmem_tiled_copy_QKV: cute.TiledCopy,
gmem_tiled_copy_O: cute.TiledCopy,
tiled_mma: cute.TiledMma,
SharedStorage: cutlass.Constexpr,
):
"""Kernel function for flash attention v2.
:param mQ: query tensor
:type mQ: cute.Tensor
:param mK: key tensor
:type mK: cute.Tensor
:param mV: value tensor
:type mV: cute.Tensor
:param mO: output tensor
:type mO: cute.Tensor
:param softmax_scale_log2: softmax scale log2
:type softmax_scale_log2: cutlass.Float32
:param sQ_layout: query layout
:type sQ_layout: cute.ComposedLayout
:param sKV_layout: key/value layout
:type sKV_layout: cute.ComposedLayout
:param sO_layout: output layout
:type sO_layout: cute.ComposedLayout
:param gmem_tiled_copy_QKV: tiled copy for QKV load
:type gmem_tiled_copy_QKV: cute.TiledCopy
:param gmem_tiled_copy_O: tiled copy for O store
:type gmem_tiled_copy_O: cute.TiledCopy
:param tiled_mma: tiled mma
:type tiled_mma: cute.TiledMma
:param SharedStorage: shared storage
:type SharedStorage: cutlass.Constexpr
"""
# Thread index, block index
tidx, _, _ = cute.arch.thread_idx()
m_block, batch_size, num_head = cute.arch.block_idx()
n_block_max = cute.ceil_div(mK.shape[1], self._n_block_size)
if self._is_causal:
n_block_max = min(
cute.ceil_div(
(m_block + 1) * self._m_block_size,
self._n_block_size,
),
n_block_max,
)
n_block = n_block_max - 1
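# n-blocks are processed in reverse order, starting from the last (possibly partial or causally masked) block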
# ///////////////////////////////////////////////////////////////////////////////
# Get the appropriate tiles for this thread block.
# ///////////////////////////////////////////////////////////////////////////////
# (m_block_size, head_dim)
gQ = cute.local_tile(
mQ[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
# (n_block_size, head_dim, n_block)
gK = cute.local_tile(
mK[batch_size, None, num_head, None],
(self._n_block_size, self._head_dim_padded),
(None, 0),
)
# (n_block_size, head_dim, n_block)
gV = cute.local_tile(
mV[batch_size, None, num_head, None],
(self._n_block_size, self._head_dim_padded),
(None, 0),
)
# ///////////////////////////////////////////////////////////////////////////////
# Get shared memory buffer
# ///////////////////////////////////////////////////////////////////////////////
smem = cutlass.utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
sQ = storage.sQ.get_tensor(sQ_layout)
sK = storage.sK.get_tensor(sKV_layout)
sV = storage.sV.get_tensor(sKV_layout)
# Transpose view of V to tensor with layout (head_dim, n_block_size) for tiled mma
sVt = cute.composition(
sV,
cute.make_layout(
(self._head_dim_padded, self._n_block_size),
stride=(self._n_block_size, 1),
),
)
gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_slice(tidx)
# (CPY_Atom, CPY_M, CPY_K)
tQgQ = gmem_thr_copy_QKV.partition_S(gQ)
tQsQ = gmem_thr_copy_QKV.partition_D(sQ)
# (CPY_Atom, CPY_N, CPY_K, n_block)
tKgK = gmem_thr_copy_QKV.partition_S(gK)
tKsK = gmem_thr_copy_QKV.partition_D(sK)
# (CPY_Atom, CPY_N, CPY_K, n_block)
tVgV = gmem_thr_copy_QKV.partition_S(gV)
tVsV = gmem_thr_copy_QKV.partition_D(sV)
# ///////////////////////////////////////////////////////////////////////////////
# Tile MMA compute thread partitions and allocate accumulators
# ///////////////////////////////////////////////////////////////////////////////
thr_mma = tiled_mma.get_slice(tidx)
tSrQ = thr_mma.make_fragment_A(thr_mma.partition_A(sQ))
tSrK = thr_mma.make_fragment_B(thr_mma.partition_B(sK))
tOrVt = thr_mma.make_fragment_B(thr_mma.partition_B(sVt))
acc_shape_O = thr_mma.partition_shape_C(
(self._m_block_size, self._head_dim_padded)
)
acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32)
acc_O.fill(0.0)
# ///////////////////////////////////////////////////////////////////////////////
# Smem copy atom tiling
# ///////////////////////////////////////////////////////////////////////////////
smem_copy_atom_Q = cute.make_copy_atom(
warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4),
self._dtype,
)
smem_copy_atom_K = cute.make_copy_atom(
warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4),
self._dtype,
)
smem_copy_atom_V = cute.make_copy_atom(
warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
self._dtype,
)
smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma)
smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma)
smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma)
smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx)
smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx)
smem_thr_copy_V = smem_tiled_copy_V.get_slice(tidx)
tSsQ = smem_thr_copy_Q.partition_S(sQ)
tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ)
tSsK = smem_thr_copy_K.partition_S(sK)
tSrK_copy_view = smem_thr_copy_K.retile(tSrK)
tOsVt = smem_thr_copy_V.partition_S(sVt)
tOrVt_copy_view = smem_thr_copy_V.retile(tOrVt)
# ///////////////////////////////////////////////////////////////////////////////
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
# of tile_shape
# ///////////////////////////////////////////////////////////////////////////////
# Construct identity layout for Q and KV
mcQ = cute.make_identity_tensor(mQ.layout.shape)
mcKV = cute.make_identity_tensor(mK.layout.shape)
cQ = cute.local_tile(
mcQ[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
cKV = cute.local_tile(
mcKV[batch_size, None, num_head, None],
(self._n_block_size, self._head_dim_padded),
(n_block, 0),
)
# Repeat the partitioning with identity layouts
tQcQ = gmem_thr_copy_QKV.partition_S(cQ)
tKVcKV = gmem_thr_copy_QKV.partition_S(cKV)
# Allocate predicate tensors for the k (head_dim) mode only; the m/n residue is handled separately.
# This reduces register pressure and gives a 2-3% performance gain compared with allocating predicates for the whole tile.
tQpQ = cute.make_fragment(
cute.make_layout(
(
tQsQ.shape[0][1],
cute.size(tQsQ, mode=[1]),
cute.size(tQsQ, mode=[2]),
),
stride=(cute.size(tQsQ, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
tKVpKV = cute.make_fragment(
cute.make_layout(
(
tKsK.shape[0][1],
cute.size(tKsK, mode=[1]),
cute.size(tKsK, mode=[2]),
),
stride=(cute.size(tKsK, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
# Set predicates for the head_dim bounds; the seqlen_q/seqlen_k bounds are handled when the first (residue) tile is processed.
for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
tQpQ[rest_v, 0, rest_k] = cute.elem_less(
tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
)
for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
)
# ///////////////////////////////////////////////////////////////////////////////
# Prefetch Prologue
# ///////////////////////////////////////////////////////////////////////////////
# Start async loads of the last mn-tile, where we take care of the mn residue
for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]):
cute.copy(
gmem_tiled_copy_QKV,
tQgQ[None, m, None],
tQsQ[None, m, None],
pred=tQpQ[None, m, None],
)
else:
# Clear the smem tiles to account for predicated off loads
tQsQ[None, m, None].fill(0)
for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]):
cute.copy(
gmem_tiled_copy_QKV,
tKgK[None, n, None, n_block],
tKsK[None, n, None],
pred=tKVpKV[None, n, None],
)
else:
# Clear the smem tiles to account for predicated off loads
tKsK[None, n, None].fill(0)
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
# Softmax intermediate result: row_max and row_sum
# ///////////////////////////////////////////////////////////////////////////////
# shape: (atom_v_m * rest_m)
row_max = cute.make_fragment(
(acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
)
# shape: (atom_v_m * rest_m)
row_sum = cute.make_fragment(
(acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
)
row_max.fill(-cutlass.Float32.inf)
row_sum.fill(0.0)
# group parameters for compute_one_n_block
basic_params = SimpleNamespace(
m_block=m_block,
n_block=n_block,
mQ=mQ,
mK=mK,
batch_size=batch_size,
num_head=num_head,
)
mma_params = SimpleNamespace(
thr_mma=thr_mma,
tiled_mma=tiled_mma,
tSrQ=tSrQ,
tSrK=tSrK,
tOrVt=tOrVt,
acc_O=acc_O,
)
gmem_copy_params = SimpleNamespace(
gmem_tiled_copy_QKV=gmem_tiled_copy_QKV,
tKVcKV=tKVcKV,
tKgK=tKgK,
tKsK=tKsK,
tVgV=tVgV,
tVsV=tVsV,
tKVpKV=tKVpKV,
)
smem_copy_params = SimpleNamespace(
smem_tiled_copy_Q=smem_tiled_copy_Q,
smem_tiled_copy_K=smem_tiled_copy_K,
smem_tiled_copy_V=smem_tiled_copy_V,
tSsQ=tSsQ,
tSrQ_copy_view=tSrQ_copy_view,
tSsK=tSsK,
tSrK_copy_view=tSrK_copy_view,
tOsVt=tOsVt,
tOrVt_copy_view=tOrVt_copy_view,
)
softmax_params = SimpleNamespace(
row_max=row_max,
row_sum=row_sum,
softmax_scale_log2=softmax_scale_log2,
)
# Start processing of the first n-block.
# For performance reasons, we separate out two kinds of iterations:
# those that need masking on S, and those that don't.
# We need masking on S for the very last block when the K/V sequence length is not a multiple of n_block_size.
# We also need masking on S if it's causal, for the last ceil_div(m_block_size, n_block_size) blocks.
# We will have at least 1 "masking" iteration.
mask_steps = 1
if cutlass.const_expr(self._is_causal):
mask_steps = cute.ceil_div(self._m_block_size, self._n_block_size)
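# e.g. m_block_size = 128 with n_block_size = 64 gives ceil(128 / 64) = 2 masking iterations when causal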
for n_tile in cutlass.range_constexpr(mask_steps):
n_block = n_block_max - n_tile - 1
basic_params.n_block = n_block
if cutlass.const_expr(self._is_causal):
if n_block >= 0:
self.compute_one_n_block(
basic_params,
mma_params,
gmem_copy_params,
smem_copy_params,
softmax_params,
is_first_n_block=(n_tile == 0),
in_mask_steps=True,
)
else:
self.compute_one_n_block(
basic_params,
mma_params,
gmem_copy_params,
smem_copy_params,
softmax_params,
is_first_n_block=True,
in_mask_steps=True,
)
# Process the remaining n-blocks in reverse order; their async K/V loads need no seqlen residue handling.
for n_tile in range(mask_steps, n_block_max, 1):
n_block = n_block_max - n_tile - 1
basic_params.n_block = n_block
self.compute_one_n_block(
basic_params,
mma_params,
gmem_copy_params,
smem_copy_params,
softmax_params,
is_first_n_block=False,
in_mask_steps=False,
)
# ///////////////////////////////////////////////////////////////////////////////
# Epilogue
# ///////////////////////////////////////////////////////////////////////////////
# normalize acc_O by row_sum and calculate the lse
self.normalize_softmax(acc_O, row_sum)
# store acc_O
rO = cute.make_fragment_like(acc_O, self._dtype)
rO.store(acc_O.load().to(self._dtype))
# reuse sQ's data iterator
sO = cute.make_tensor(sQ.iterator, sO_layout)
# smem copy atom for O
smem_copy_atom_O = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), self._dtype
)
# tiled copy atom for O
smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma)
smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx)
taccOrO = smem_thr_copy_O.retile(rO)
taccOsO = smem_thr_copy_O.partition_D(sO)
# copy acc O from rmem to smem with the smem copy atom
cute.copy(
smem_copy_atom_O,
taccOrO,
taccOsO,
)
gO = cute.local_tile(
mO[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
tOsO = gmem_thr_copy_O.partition_S(sO)
tOgO = gmem_thr_copy_O.partition_D(gO)
tOrO = cute.make_fragment_like(tOgO, self._dtype)
# Sync to make sure all smem stores of O are done before they are read back below.
cute.arch.barrier()
# load acc O from smem to rmem for wider vectorization
cute.copy(
gmem_tiled_copy_O,
tOsO,
tOrO,
)
mcO = cute.make_identity_tensor(mO.layout.shape)
cO = cute.local_tile(
mcO[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
tOcO = gmem_thr_copy_O.partition_D(cO)
tOpO = cute.make_fragment(
cute.make_layout(
(tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]),
stride=(tOgO.shape[2], 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
tOpO[rest_v, 0, rest_n] = cute.elem_less(
tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
)
# copy acc O from rmem to gmem
for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])):
if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]):
cute.copy(
gmem_tiled_copy_O,
tOrO[None, rest_m, None],
tOgO[None, rest_m, None],
pred=tOpO[None, rest_m, None],
)
@cute.jit
def compute_one_n_block(
self,
basic_params: SimpleNamespace,
mma_params: SimpleNamespace,
gmem_copy_params: SimpleNamespace,
smem_copy_params: SimpleNamespace,
softmax_params: SimpleNamespace,
is_first_n_block: cutlass.Constexpr,
in_mask_steps: cutlass.Constexpr,
):
"""Compute one n_block of S/O.
This function provides different variants for processing the first n block versus subsequent blocks,
as well as variants for handling masked and unmasked steps.
:param basic_params: basic parameters
:type basic_params: SimpleNamespace
:param mma_params: mma parameters
:type mma_params: SimpleNamespace
:param gmem_copy_params: gmem copy parameters
:type gmem_copy_params: SimpleNamespace
:param smem_copy_params: smem copy parameters
:type smem_copy_params: SimpleNamespace
:param softmax_params: softmax parameters
:type softmax_params: SimpleNamespace
:param is_first_n_block: is first n block
:type is_first_n_block: cutlass.Constexpr
:param in_mask_steps: whether this n_block is one of the masking iterations
:type in_mask_steps: cutlass.Constexpr
"""
acc_shape_S = mma_params.thr_mma.partition_shape_C(
(self._m_block_size, self._n_block_size)
)
acc_S = cute.make_fragment(acc_shape_S, cutlass.Float32)
acc_S.fill(0.0)
# wait for smem tile QK before mma calculation for S
cute.arch.cp_async_wait_group(0)
cute.arch.barrier()
# Load the smem tile of V for the O gemm; the first tile needs special handling to avoid loading out-of-bounds (potentially NaN) values.
# The `if` below is a constexpr branch resolved at compile time, so it is not generated in the IR.
if is_first_n_block:
for n in cutlass.range_constexpr(cute.size(gmem_copy_params.tVsV.shape[1])):
if cute.elem_less(
gmem_copy_params.tKVcKV[0, n, 0][1],
basic_params.mK.layout.shape[1],
):
cute.copy(
gmem_copy_params.gmem_tiled_copy_QKV,
gmem_copy_params.tVgV[None, n, None, basic_params.n_block],
gmem_copy_params.tVsV[None, n, None],
pred=gmem_copy_params.tKVpKV[None, n, None],
)
else:
gmem_copy_params.tVsV[None, n, None].fill(0.0)
else:
cute.copy(
gmem_copy_params.gmem_tiled_copy_QKV,
gmem_copy_params.tVgV[None, None, None, basic_params.n_block],
gmem_copy_params.tVsV,
pred=gmem_copy_params.tKVpKV,
)
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
# S gemm calculation
# ///////////////////////////////////////////////////////////////////////////////
# load first QK k-block from smem to rmem for mma
cute.copy(
smem_copy_params.smem_tiled_copy_Q,
smem_copy_params.tSsQ[None, None, 0],
smem_copy_params.tSrQ_copy_view[None, None, 0],
)
cute.copy(
smem_copy_params.smem_tiled_copy_K,
smem_copy_params.tSsK[None, None, 0],
smem_copy_params.tSrK_copy_view[None, None, 0],
)
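# Register pipeline: while the gemm below consumes k-block k, the copies inside the loop prefetch
# k-block k + 1 from smem into registers, overlapping smem loads with tensor core math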
# mma for S
for k in cutlass.range_constexpr(cute.size(smem_copy_params.tSsQ.shape[2])):
# load next QK k-block from smem to rmem for mma
k_next = (k + 1) % cute.size(smem_copy_params.tSsQ.shape[2])
cute.copy(
smem_copy_params.smem_tiled_copy_Q,
smem_copy_params.tSsQ[None, None, k_next],
smem_copy_params.tSrQ_copy_view[None, None, k_next],
)
cute.copy(
smem_copy_params.smem_tiled_copy_K,
smem_copy_params.tSsK[None, None, k_next],
smem_copy_params.tSrK_copy_view[None, None, k_next],
)
cute.gemm(
mma_params.tiled_mma,
acc_S,
mma_params.tSrQ[None, None, k],
mma_params.tSrK[None, None, k],
acc_S,
)
# wait for smem tile V for O
cute.arch.cp_async_wait_group(0)
cute.arch.barrier()
if basic_params.n_block > 0:
cute.copy(
gmem_copy_params.gmem_tiled_copy_QKV,
gmem_copy_params.tKgK[None, None, None, basic_params.n_block - 1],
gmem_copy_params.tKsK,
pred=gmem_copy_params.tKVpKV,
)
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
# online softmax
# ///////////////////////////////////////////////////////////////////////////////
self.softmax_rescale_O(
basic_params,
mma_params,
softmax_params,
acc_S,
is_first_n_block,
in_mask_steps,
)
rP = cute.make_fragment_like(acc_S, self._dtype)
rP.store(acc_S.load().to(self._dtype))
# ///////////////////////////////////////////////////////////////////////////////
# O gemm calculation
# ///////////////////////////////////////////////////////////////////////////////
# Convert the layout of acc_S to a layout accepted by the O gemm.
# Because the mma instruction shape is 16x8x16, we convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2):
# (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
rP_layout_divided = cute.logical_divide(rP.layout, (None, None, 2))
rP_mma_view = cute.make_layout(
(
(rP_layout_divided.shape[0], rP_layout_divided.shape[2][0]),
rP_layout_divided.shape[1],
rP_layout_divided.shape[2][1],
),
stride=(
(rP_layout_divided.stride[0], rP_layout_divided.stride[2][0]),
rP_layout_divided.stride[1],
rP_layout_divided.stride[2][1],
),
)
tOrS = cute.make_tensor(rP.iterator, rP_mma_view)
# load first V k-block from smem to rmem for mma
cute.copy(
smem_copy_params.smem_tiled_copy_V,
smem_copy_params.tOsVt[None, None, 0],
smem_copy_params.tOrVt_copy_view[None, None, 0],
)
# mma for O
for k in cutlass.range_constexpr(cute.size(tOrS.shape[2])):
# load next V k-block from smem to rmem for mma
k_next = (k + 1) % cute.size(tOrS.shape[2])
cute.copy(
smem_copy_params.smem_tiled_copy_V,
smem_copy_params.tOsVt[None, None, k_next],
smem_copy_params.tOrVt_copy_view[None, None, k_next],
)
cute.gemm(
mma_params.tiled_mma,
mma_params.acc_O,
tOrS[None, None, k],
mma_params.tOrVt[None, None, k],
mma_params.acc_O,
)
@cute.jit
def softmax_rescale_O(
self,
basic_params: SimpleNamespace,
mma_params: SimpleNamespace,
softmax_params: SimpleNamespace,
acc_S: cute.Tensor,
is_first_n_block: cutlass.Constexpr,
in_mask_steps: cutlass.Constexpr,
):
"""Apply online softmax and rescale acc_O.
This function provides different variants for processing the first n block versus subsequent blocks,
as well as variants for handling masked and unmasked steps.
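Per row, the update follows the standard online-softmax recurrence, with the softmax scale
already folded into log2 space (softmax_scale_log2 = scale * log2(e)):
    m_new  = max(m_prev, rowmax(S))
    P      = exp2(softmax_scale_log2 * (S - m_new))
    l_new  = exp2(softmax_scale_log2 * (m_prev - m_new)) * l_prev + rowsum(P)
    acc_O *= exp2(softmax_scale_log2 * (m_prev - m_new))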
:param basic_params: basic parameters
:type basic_params: SimpleNamespace
:param mma_params: mma parameters
:type mma_params: SimpleNamespace
:param softmax_params: softmax parameters
:type softmax_params: SimpleNamespace
:param acc_S: acc_S tensor
:type acc_S: cute.Tensor
:param is_first_n_block: is first n_block
:type is_first_n_block: cutlass.Constexpr
:param in_mask_steps: in mask steps
:type in_mask_steps: cutlass.Constexpr
"""
# Change acc_S to M,N layout view.
acc_S_mn = self._make_acc_tensor_mn_view(acc_S)
acc_O_mn = self._make_acc_tensor_mn_view(mma_params.acc_O)
row_max_prev = None
# If this is not the first n_block, save a copy of the previous row_max before it is updated.
if cutlass.const_expr(not is_first_n_block):
row_max_prev = cute.make_fragment_like(
softmax_params.row_max, cutlass.Float32
)
cute.basic_copy(softmax_params.row_max, row_max_prev)
# If we are in the masking steps, build an identity tensor so the residual/causal elements of S can be set to -inf for the softmax.
tScS_mn = None
if cutlass.const_expr(in_mask_steps):
mcS = cute.make_identity_tensor(
(
basic_params.mQ.shape[0],
basic_params.mQ.shape[1],
basic_params.mQ.shape[2],
basic_params.mK.shape[1],
)
)
cS = cute.local_tile(
mcS[basic_params.batch_size, None, basic_params.num_head, None],
(self._m_block_size, self._n_block_size),
(basic_params.m_block, basic_params.n_block),
)
tScS = mma_params.thr_mma.partition_C(cS)
tScS_mn = self._make_acc_tensor_mn_view(tScS)
# Each iteration processes one row of acc_S
for r in cutlass.range_constexpr(cute.size(softmax_params.row_max)):
# mask residual of S with -inf
if cutlass.const_expr(in_mask_steps):
if cutlass.const_expr(not self._is_causal):
# traverse column index.
for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
if cute.elem_less(
basic_params.mK.shape[1], tScS_mn[0, c][3] + 1
):
acc_S_mn[r, c] = -cutlass.Float32.inf
else:
# Get the column index limit based on the current row. Only the row index matters here, so the column index is set to 0.
col_idx_limit = cutlass.min(
tScS_mn[r, 0][1] + 1, basic_params.mK.shape[1]
)
# traverse column index.
for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])):
# Only the column index matters here, so the row index is set to 0.
if cute.elem_less(col_idx_limit, tScS_mn[0, c][3] + 1):
acc_S_mn[r, c] = -cutlass.Float32.inf
# (n_block_size)
acc_S_row = acc_S_mn[r, None].load()
# row_max_cur_row => f32
row_max_cur_row = acc_S_row.reduce(
cute.ReductionOp.MAX, -cutlass.Float32.inf, 0
)
# quad reduction for row_max
row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row)
row_max_prev_row = None
# if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row.
if cutlass.const_expr(not is_first_n_block):
row_max_prev_row = row_max_prev[r]
row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row)
if cutlass.const_expr(self._is_causal):
row_max_cur_row = (
0.0 if row_max_cur_row == -cutlass.Float32.inf else row_max_cur_row
)
# compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e))
acc_S_row_exp = cute.TensorSSA(
self._exp2f(
acc_S_row * softmax_params.softmax_scale_log2
- row_max_cur_row * softmax_params.softmax_scale_log2
),
tuple(acc_S_row.shape),
cutlass.Float32,
)
# acc_S_row_sum => f32
acc_S_row_sum = acc_S_row_exp.reduce(
cute.ReductionOp.ADD, cutlass.Float32.zero, 0
)
# If this is not the first n_block, compute exp(prev_max - cur_max) to rescale the previous row_sum and acc_O.
if cutlass.const_expr(not is_first_n_block):
prev_minus_cur_exp = self._exp2f(
row_max_prev_row * softmax_params.softmax_scale_log2
- row_max_cur_row * softmax_params.softmax_scale_log2
)
acc_S_row_sum = (
acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp
)
acc_O_mn[r, None] = acc_O_mn[r, None].load() * prev_minus_cur_exp
# update row_max, row_sum and acc_S
softmax_params.row_max[r] = row_max_cur_row
softmax_params.row_sum[r] = acc_S_row_sum
acc_S_mn[r, None] = acc_S_row_exp
@cute.jit
def normalize_softmax(
self,
acc_O: cute.Tensor,
row_sum: cute.Tensor,
):
"""Normalize acc_O by row_sum.
:param acc_O: input tensor
:type acc_O: cute.Tensor
:param row_sum: row_sum tensor
:type row_sum: cute.Tensor
"""
# do quad reduction for row_sum.
acc_O_mn = self._make_acc_tensor_mn_view(acc_O)
for r in cutlass.range_constexpr(cute.size(row_sum)):
row_sum[r] = self._threadquad_reduce_sum(row_sum[r])
# If row_sum is zero or nan, use a scale of 1.0 instead of 1 / row_sum.
acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
scale = (
1.0 if acc_O_mn_row_is_zero_or_nan else cute.arch.rcp_approx(row_sum[r])
)
acc_O_mn[r, None] = acc_O_mn[r, None].load() * scale
def _make_acc_tensor_mn_view(self, acc: cute.Tensor) -> cute.Tensor:
"""make acc tensor as mn layout view
:param acc: input tensor
:type acc: cute.Tensor
:return: acc tensor mn layout view
:rtype: cute.Tensor
"""
acc_layout_col_major = cute.make_layout(acc.layout.shape)
acc_layout_mn = cute.make_layout(
(
(
acc_layout_col_major.shape[0][1],
acc_layout_col_major.shape[1],
), # MMA_M
(
acc_layout_col_major.shape[0][0],
acc_layout_col_major.shape[2],
), # MMA_N
),
stride=(
(
acc_layout_col_major.stride[0][1],
acc_layout_col_major.stride[1],
), # MMA_M
(
acc_layout_col_major.stride[0][0],
acc_layout_col_major.stride[2],
), # MMA_N
),
)
acc_layout_mn = cute.composition(acc.layout, acc_layout_mn)
return cute.make_tensor(acc.iterator, acc_layout_mn)
def _threadquad_reduce(self, val: cutlass.Float32, op: Callable) -> cutlass.Float32:
"""thread quad reduction
:param val: register value
:type val: cutlass.Float32
:param op: binary operator
:type op: Callable
:return: reduced value
:rtype: cutlass.Float32
"""
val = op(
val,
cute.arch.shuffle_sync_bfly(val, offset=2, mask=-1, mask_and_clamp=31),
)
val = op(
val,
cute.arch.shuffle_sync_bfly(val, offset=1, mask=-1, mask_and_clamp=31),
)
return val
def _threadquad_reduce_max(self, val: cutlass.Float32) -> cutlass.Float32:
"""thread quad reduction max
:param val: register value
:type val: cutlass.Float32
:return: max value
:rtype: cutlass.Float32
"""
return self._threadquad_reduce(val, lambda x, y: cute.arch.fmax(x, y))
def _threadquad_reduce_sum(self, val: cutlass.Float32) -> cutlass.Float32:
"""thread quad reduction sum
:param val: register value
:type val: cutlass.Float32
:return: sum value
:rtype: cutlass.Float32
"""
return self._threadquad_reduce(val, lambda x, y: x + y)
def _exp2f(
self, x: Union[cute.TensorSSA, cutlass.Float32]
) -> Union[cute.TensorSSA, cutlass.Float32]:
"""exp2f calculation for both vector and scalar.
:param x: input value
:type x: cute.TensorSSA or cutlass.Float32
:return: exp2 value
:rtype: cute.TensorSSA or cutlass.Float32
"""
if isinstance(x, cute.TensorSSA):
res = cute.make_fragment(x.shape, cutlass.Float32)
res.store(x)
for i in range(cute.size(x.shape)):
res[i] = self._exp2f(res[i])
return res.load()
return cute.arch.exp2(x)
def run(
dtype: Type[cutlass.Numeric],
batch_size: int,
seqlen_q: int,
seqlen_k: int,
num_head: int,
head_dim: int,
softmax_scale: float = 1.0,
m_block_size: int = 128,
n_block_size: int = 128,
num_threads: int = 128,
is_causal: bool = False,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
# Reject unsupported configurations early
if not FlashAttentionForwardAmpere.can_implement(
dtype,
head_dim,
m_block_size,
n_block_size,
num_threads,
is_causal,
):
raise TypeError(
f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}"
)
print(f"Running Ampere SM80 FlashAttentionForward test with:")
print(f" dtype: {dtype}")
print(f" batch_size: {batch_size}")
print(f" seqlen_q: {seqlen_q}")
print(f" seqlen_k: {seqlen_k}")
print(f" num_head: {num_head}")
print(f" head_dim: {head_dim}")
print(f" softmax_scale: {softmax_scale}")
print(f" m_block_size: {m_block_size}")
print(f" n_block_size: {n_block_size}")
print(f" num_threads: {num_threads}")
print(f" is_causal: {is_causal}")
print(f" warmup_iterations: {warmup_iterations}")
print(f" iterations: {iterations}")
print(f" skip_ref_check: {skip_ref_check}")
print(f" use_cold_l2: {use_cold_l2}")
# Create tensor Q/K/V/O
def create_tensor(
batch_size: int,
seqlen: int,
num_head: int,
head_dim: int,
dtype: Type[cutlass.Numeric],
) -> cute.Tensor:
# (batch_size, seqlen, num_head, head_dim)
shape = (batch_size, seqlen, num_head, head_dim)
torch_tensor = (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
.to(dtype=cutlass_torch.dtype(dtype))
.cuda()
)
# assume input is 16B aligned.
cute_tensor = (
from_dlpack(torch_tensor, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3,
stride_order=torch_tensor.dim_order(),
divisibility=(128 // dtype.width),
)
)
return cute_tensor, torch_tensor
q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
fa2_fwd = FlashAttentionForwardAmpere(
head_dim,
m_block_size,
n_block_size,
num_threads,
is_causal,
)
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# compile the fa2 forward pass
compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream)
if not skip_ref_check:
compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream)
torch.cuda.synchronize()
q_ref = q_torch.permute(0, 2, 1, 3)
k_ref = k_torch.permute(0, 2, 1, 3)
v_ref = v_torch.permute(0, 2, 1, 3)
torch.backends.cuda.enable_flash_sdp(enabled=True)
ref_o = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
).permute(0, 2, 1, 3)
torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
print("Results verified successfully!")
def generate_tensors():
q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
return testing.JitArguments(
q_workspace,
k_workspace,
v_workspace,
o_workspace,
softmax_scale,
current_stream,
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
q_torch.numel() * q_torch.element_size()
+ k_torch.numel() * k_torch.element_size()
+ v_torch.numel() * v_torch.element_size()
+ o_torch.numel() * o_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
avg_time_us = testing.benchmark(
compiled_fa2_fwd,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="example of flash attention v2 with CuTe on GPU"
)
parser.add_argument("--dtype", type=cutlass.dtype, default=cutlass.BFloat16)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--seqlen_q", type=int, default=8192)
parser.add_argument("--seqlen_k", type=int, default=8192)
parser.add_argument("--num_head", type=int, default=16)
parser.add_argument("--head_dim", type=int, default=128)
parser.add_argument("--softmax_scale", type=float, default=0.5)
parser.add_argument("--m_block_size", type=int, default=128)
parser.add_argument("--n_block_size", type=int, default=64)
parser.add_argument("--num_threads", type=int, default=128)
parser.add_argument("--is_causal", action="store_true", help="Enable causal mask")
parser.add_argument("--warmup_iterations", type=int, default=3)
parser.add_argument("--iterations", type=int, default=10)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference check"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
run(
args.dtype,
args.batch_size,
args.seqlen_q,
args.seqlen_k,
args.num_head,
args.head_dim,
args.softmax_scale,
args.m_block_size,
args.n_block_size,
args.num_threads,
args.is_causal,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")