v4.1 release update v2. (#2481)

Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -64,7 +64,7 @@ ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutla
ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not
enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do
that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB
to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C
to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C
which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the
data type of output ElementOutput (int32_t), the number of elements per vector memory access (16),
data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X +
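
The template variables this passage describes reduce to a handful of type aliases. A minimal sketch follows, using placeholder element types and a 128-bit vector access width rather than the exact values of any one example file:

```
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// Element types, as in the prose above (placeholders, not a specific example's choices).
using ElementInputA          = cutlass::half_t;   // data type of operand A
using ElementInputB          = cutlass::half_t;   // data type of operand B
using ElementOutput          = float;             // data type of the output matrix
using ElementAccumulator     = float;             // data type of internal accumulation
using ElementComputeEpilogue = float;             // data type used to compute alpha * X + beta * C

// Layouts convey how each matrix is laid out linearly in memory.
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Epilogue rule alpha * X + beta * C, vectorized over 128-bit memory accesses
// (128 / sizeof_bits<ElementOutput> gives e.g. 16 for int8 outputs, 4 for float).
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // elements per vectorized memory access
    ElementAccumulator,
    ElementComputeEpilogue>;
```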

View File

@ -64,7 +64,7 @@ ElementComputeEpilogue (int32_t), ElementInputA (int8_t), ElementInputB (int8_t)
(int32_t). Communicating just the data type is not enough. As the data is laid out linearly in
memory, we have to convey the layout of matrices. We do that by initializing template variable
LayoutInputA to column major cutlass variable, LayoutInputB to row major and LayoutOutput to row
major. Next, we setup rules to comptue alpha * X + beta * C which is called epilogue of the kernel.
major. Next, we setup rules to compute alpha * X + beta * C which is called epilogue of the kernel.
We initialize template variable EpilogueOp, which takes the data type of output ElementOutput
(int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t)
and data type of computation of linear combination (alpha * X + beta * C).

View File

@ -66,7 +66,7 @@ ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB
ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out
linearly in memory, we have to convey the layout of tensors. We do that by initializing template
variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup
rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template
rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template
variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of
elements per vector memory access (32), data type of accumulator (int32_t) and data type of
computation of linear combination (alpha * X + beta * C).

View File

@ -177,7 +177,7 @@ public:
if(args.split_k_mode == SplitKMode::kParallel) {
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
// The user needs to call a reduction operator to optain the final output tensor
// The user needs to call a reduction operator to obtain the final output tensor
workspace_bytes =
sizeof(ElementAccumulator) *
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *
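
For a sense of scale, the split-K parallel workspace grows linearly in both the output tensor size and the number of k-slices. A small arithmetic sketch with made-up extents (not taken from the example):

```
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical conv output extent N=2, P=56, Q=56, K=64, accumulated in float,
  // with a parallel split-K factor of 4: one partial output tile per k-slice.
  std::size_t tensor_c_size   = std::size_t(2) * 56 * 56 * 64;  // what implicit_gemm_tensor_c_size(...) would report
  int         split_k_slices  = 4;
  std::size_t workspace_bytes = sizeof(float) * tensor_c_size * split_k_slices;
  std::printf("split-K parallel workspace: %zu bytes (~%.1f MiB)\n",
              workspace_bytes, workspace_bytes / (1024.0 * 1024.0));
  return 0;
}
```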

View File

@ -153,7 +153,7 @@ struct Options {
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
<< " run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n"
<< " back-to-back GEMMs are subject to.s"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"

View File

@ -248,7 +248,7 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problmes in the group. Thus,
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
@ -402,7 +402,7 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problmes in the group. Thus,
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
@ -434,7 +434,7 @@ struct B2bGemm {
// Only row-major outputs are currently supported, so no transpose is performed
}
/// Returns non-grouped paramaters to be used as input to the kernel-level
/// Returns non-grouped parameters to be used as input to the kernel-level
/// operator for the problem indicated by problem_visitor.
CUTLASS_HOST_DEVICE
Params to_single_params(const ProblemVisitor& problem_visitor) const {

View File

@ -560,7 +560,7 @@ struct DefaultB2bConv2dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
// multistage pipeline with interleaved layout.
template <
typename ElementA,

View File

@ -606,7 +606,7 @@ struct DefaultB2bConv2dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
// multistage pipeline with interleaved layout.
/// Accumulator will be staged in shared memory.
template <

View File

@ -277,7 +277,7 @@ public:
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumualtor tile
FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment

View File

@ -298,7 +298,7 @@ public:
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumualtor tile
FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment

View File

@ -93,7 +93,7 @@ template <
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation perfomed by GEMM
/// Operation performed by GEMM
typename Operator,
/// Epilogue output operator
typename EpilogueOutputOp,

View File

@ -203,7 +203,7 @@ requires any memory for scratch space.
If yes, we reserve scratch space and pass it along
with other arguments to initialize the CUTLASS kernel.
After lauching the CUTLASS kernel, this example runs
After launching the CUTLASS kernel, this example runs
a reference convolution kernel (from CUTLASS utilities)
to check correctness.
*/
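
The sequence sketched in that comment (query feasibility, size and reserve scratch space, initialize, launch, then verify against a reference) is the same host-side pattern every CUTLASS 2.x device operator follows. A minimal illustration using a plain SIMT GEMM with default template parameters; data initialization and the reference comparison are elided:

```
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/device_memory.h"

int main() {
  using Gemm = cutlass::gemm::device::Gemm<
      float, cutlass::layout::ColumnMajor,   // A
      float, cutlass::layout::ColumnMajor,   // B
      float, cutlass::layout::ColumnMajor>;  // C / D

  int M = 128, N = 128, K = 128;
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A({M, K}), B({K, N}), C({M, N});
  // ... fill A, B, C and copy to device (omitted) ...

  Gemm::Arguments args({M, N, K},
                       A.device_ref(), B.device_ref(),
                       C.device_ref(), C.device_ref(),
                       {1.0f, 0.0f});                                         // alpha, beta

  Gemm gemm_op;
  if (gemm_op.can_implement(args) != cutlass::Status::kSuccess) return 1;     // feasibility check

  size_t workspace_size = Gemm::get_workspace_size(args);                     // does it need scratch space?
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);      // reserve it if so

  if (gemm_op.initialize(args, workspace.get()) != cutlass::Status::kSuccess) return 1;
  if (gemm_op() != cutlass::Status::kSuccess) return 1;                       // launch the CUTLASS kernel

  // A reference kernel (from cutlass/util) would be run here to check correctness.
  return 0;
}
```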

View File

@ -144,7 +144,7 @@ int run() {
// Construct Gemm ProblemSize with user defined output size
cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024};
// Stride factor shows the distance between two elements in the differnet dimensions. The
// Stride factor shows the distance between two elements in the different dimensions. The
// first data is the logical distance between two rows, the second is between two columns.
// CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory<Layout>::layout_factory
// to help to convert stride_factor to the two strides.

View File

@ -55,7 +55,7 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
// Define the overal warp-level problem shape
// Define the overall warp-level problem shape
int const kM = 27;
int const kN = 31;
int const kK = 17;

View File

@ -59,7 +59,7 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
// Define the overal warp-level problem shape
// Define the overall warp-level problem shape
int const kM = 14;
int const kN = 27;
int const kK = 17;

View File

@ -30,7 +30,7 @@
**************************************************************************************************/
// This example fuses gather before GEMM and scatter after GEMM into the same
// GEMM kernel. Gather and scatter operation is controled by an index vector
// GEMM kernel. Gather and scatter operation is controlled by an index vector
// to select rows or columns from A, B, C or D matrices.
//
// Suppose, all matrices are column major. The pseudo code of the fused kernel

View File

@ -87,7 +87,7 @@ public:
using ElementLayernormCompute = ElementLayernormCompute_;
using ThreadblockShape = ThreadblockShape_;
// Pre-processing has ensured the layout equivelent to RowMajor
// Pre-processing has ensured the layout equivalent to RowMajor
using Layout = cutlass::layout::RowMajor;
using TensorVariance = TensorRef<ElementVariance, Layout>;

View File

@ -87,7 +87,7 @@ parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorC32RSK32"],

View File

@ -86,7 +86,7 @@ parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],

View File

@ -55,7 +55,7 @@
```
In practice, and for numerical stability reasons,
we also substract the maximum so far (`mi`) before doing
we also subtract the maximum so far (`mi`) before doing
the exponential. When we encounter new keys, the maximum
used to compute O so far (`m_prime`) can differ from the
current maximum, so we update O before accumulating with
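
Stated as formulas (notation mine, not the file's): with running row maximum $m_i$ and the maximum $m'_i$ that was in effect when the partial output $O_i$ was accumulated,

$$ p_{ij} = e^{\,s_{ij} - m_i}, \qquad O_i \leftarrow e^{\,m'_i - m_i}\, O_i + \sum_j p_{ij}\, V_j, \qquad m'_i \leftarrow m_i, $$

so the exponentials stay bounded and previously accumulated contributions are rescaled whenever a new key block raises the maximum.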

View File

@ -55,7 +55,7 @@
```
In practice, and for numerical stability reasons,
we also substract the maximum so far (`mi`) before doing
we also subtract the maximum so far (`mi`) before doing
the exponential. When we encounter new keys, the maximum
used to compute O so far (`m_prime`) can differ from the
current maximum, so we update O before accumulating with

View File

@ -31,7 +31,7 @@
/*! \file
\brief Cutlass provides helper template functions to figure out the right
datastructures to instanciate to run a GEMM with various parameters (see
datastructures to instantiate to run a GEMM with various parameters (see
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
instantiation priority rules, it will only create an MmaMultiStage with
kStages=3 (otherwise creates an MmePipelined - which is not compatible with
@ -83,7 +83,7 @@ template <
typename InstructionShape,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation perfomed by GEMM
/// Operation performed by GEMM
typename Operator,
typename Enable_ = void>
struct FindDefaultMma {

View File

@ -522,7 +522,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
// For API compatibility with MmaMultistageFromSharedMemory
// but not supported as it worsens perf: older gpus < sm80 don't
// support async tranfers and have to waste registers
// support async transfers and have to waste registers
CUTLASS_DEVICE
void set_prologue_done(bool value) {}
CUTLASS_DEVICE

View File

@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*! \file
\brief Instanciates the right WarpIterator to read from shared memory
\brief Instantiates the right WarpIterator to read from shared memory
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
data dumped with `B2bGemm::accumToSmem`.
*/

View File

@ -86,7 +86,7 @@ namespace threadblock {
/// To be efficient, this assumes the iterator will be dereferenced and advanced
/// at least once outside any looping structure to minimize integer arithmetic.
///
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
/// dereferencing the iterator.
///
///

View File

@ -49,7 +49,7 @@
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
for this example:
a_rows - Rows in the sparse matrix.
a_cols - Colums in the sparse matrix.
a_cols - Columns in the sparse matrix.
a_ell_blocksize - Size of the ELL-Blocks.
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns)
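
To make those extents concrete, here is a tiny compile-time sketch with made-up numbers (not the example's defaults); the Blocked-Ellpack column-index array scales correspondingly, one entry per block:

```
// Made-up extents, for illustration only (not the example's defaults).
constexpr int a_rows            = 1024;  // rows in the sparse matrix
constexpr int a_cols            = 1024;  // columns in the sparse matrix
constexpr int a_ell_blocksize   = 16;    // ELL blocks are 16 x 16
constexpr int a_ell_num_columns = 128;   // packed columns of the Blocked-Ellpack storage

// tensor_a (ellValue) holds a_rows * a_ell_num_columns values, while the
// block column-index array holds one entry per 16 x 16 ELL block:
constexpr int ell_value_elems = a_rows * a_ell_num_columns;                                           // 131072
constexpr int ell_index_elems = (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize);   // 64 * 8 = 512
```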

View File

@ -153,7 +153,7 @@ class gen_device:
warp_M_tile = 32
# Determine maxmimum N_tile
# Determine maximum N_tile
Max_Ntile = 0
for layer in self.fuse_gemm_info:
n_tile = layer['mnk'][1]

View File

@ -76,9 +76,9 @@ class gen_verify:
)
def get_params(self, declartion = True):
def get_params(self, declaration = True):
code = ""
if declartion:
if declaration:
for param in self.params:
code += param[0] + " " + param[1] + ";\n"

View File

@ -64,8 +64,8 @@ def write_2_headfile(filename, file_dir, string):
with open(file_dir + filename, 'w') as f:
f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string)
def var_idx(varaiable, index):
return varaiable + str(index)
def var_idx(variable, index):
return variable + str(index)
def list_2_string(input_list, ):

View File

@ -78,7 +78,7 @@
a single default value.
CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of
the collective is specified via the schedule tags that corresond to the underlying collective's
the collective is specified via the schedule tags that correspond to the underlying collective's
dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto`
are special cases of these schedules that allow the builder to also decide the dispatch policy for you,
therefore letting the builder pick the collective specialization.
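
In code, the schedule tags and the Auto policies mentioned here appear as the final arguments of the two builders. A sketch paraphrased from the CUTLASS 3.x collective-builder examples, with placeholder element types, layouts, alignments, and tile sizes:

```
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

using ElementAB          = cutlass::half_t;
using ElementC           = cutlass::half_t;
using ElementAccumulator = float;
using TileShape          = cute::Shape<cute::_128, cute::_128, cute::_64>;
using ClusterShape       = cute::Shape<cute::_1, cute::_2, cute::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, cutlass::layout::ColumnMajor, 8,
    ElementC, cutlass::layout::ColumnMajor, 8,
    cutlass::epilogue::collective::EpilogueScheduleAuto    // builder picks the epilogue dispatch policy
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementAB, cutlass::layout::RowMajor, 8,
    ElementAB, cutlass::layout::ColumnMajor, 8,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::collective::KernelScheduleAuto          // builder picks the mainloop dispatch policy
  >::CollectiveOp;
```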

View File

@ -425,7 +425,7 @@ int main(int argc, char const **args) {
// Pipeline Depth to be used i.e number of A, B buffers in shared memory
constexpr int PipelineStages = 8;
// Let's choose a Warp-Specialized Mainloop implemention which uses TMA
// Let's choose a Warp-Specialized Mainloop implementation which uses TMA
// Note : This requires / assumes the tensors to be 16B aligned
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape,
cutlass::gemm::KernelTmaWarpSpecialized>;

View File

@ -32,7 +32,7 @@
\brief Example of a Hopper gather+GEMM+scatter kernel fusion.
This example fuses gather before GEMM and scatter after GEMM into the same
GEMM kernel. Gather and scatter operation is controled by an index vector
GEMM kernel. Gather and scatter operation is controlled by an index vector
to select rows or columns from A, B, C or D matrices.
Gather/scatter operations are always performed along a strided dimension

View File

@ -65,7 +65,7 @@
The approach relies on two things:
- The ability of CUTLASS 3 to naturally perform general tensor contractions (GETT) owing to the
flexibility of CuTe's hierarchical layouts (see example 51_hopper_gett for more details).
- The harware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
- The hardware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
(almost) arbitrary strides, which can be used to represent a permuted view of the data.
In this example we reuse the permutation classes of examples 39_gemm_permute as operation tags.

View File

@ -188,7 +188,7 @@ Running this example on an RTX 3080Ti prints the following performance numbers (
```
$> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check
Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.
Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.
Allocating tensors ... done.
Initializing data ... done.

View File

@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*! \file
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propogation kernel
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel
capable of operating on both affine and gather/scatter tensors.
This example demonstartes a few super cool features of CUTLASS and CuTe. It shows off
@ -284,7 +284,7 @@ int ampere_gather_scatter_conv_fprop(
int
main(int argc, char const** argv) {
cutlass::CommandLine cmd(argc, argv);
std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n";
std::cout << "Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.\n\n";
if (cmd.check_cmd_line_flag("help")) {
std::cout
<< "Options:\n"

View File

@ -291,7 +291,7 @@ struct Options {
// Post-process the problem sizes
bin_problems();
// Initalize alpha array
// Initialize alpha array
randomize_alpha_ptr_array(cmd);
}

View File

@ -358,7 +358,7 @@ void initialize(const Options<RasterOrderOptions> &options) {
// Layout SFA and SFB represent logically broadcasting data in CuTe.
// E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGraunularityK, K / ScaleGranularityK))
// and strides ((0, 1), (0, M / ScaleGraunuarlityM)), then each collection of ScaleGranularityM x ScaleGranularityK
// indecies in the tensor map to the same offset.
// indices in the tensor map to the same offset.
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
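
A minimal CuTe sketch of that broadcast layout, with tiny made-up extents (M = K = 8, ScaleGranularityM = ScaleGranularityK = 4) purely to show the zero-stride modes:

```
#include "cute/tensor.hpp"

// Shape  ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGranularityK, K / ScaleGranularityK))
// Stride ((0, 1),                                     (0, M / ScaleGranularityM))
inline auto toy_layout_SFA() {
  using namespace cute;
  return make_layout(
      make_shape (make_shape (Int<4>{}, Int<2>{}), make_shape (Int<4>{}, Int<2>{})),
      make_stride(make_stride(Int<0>{}, Int<1>{}), make_stride(Int<0>{}, Int<2>{})));
}
// For any (m, k), toy_layout_SFA()(make_coord(make_coord(m % 4, m / 4), make_coord(k % 4, k / 4)))
// evaluates to (m / 4) + 2 * (k / 4): all 16 indices of a 4 x 4 granule share one scale-factor offset.
```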

View File

@ -61,7 +61,7 @@
# Heuristic mode with deterministic reduction
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic
# Stream-K mode with determinsitic reduction
# Stream-K mode with deterministic reduction
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic
# Split-K mode with a splitting factor of 2 and deterministic reduction

View File

@ -850,7 +850,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
}
}
else {
std::cout << " Verfication is turned off for this run." << std::endl;
std::cout << " Verification is turned off for this run." << std::endl;
}
// Run profiling loop

View File

@ -36,7 +36,7 @@
APIs on NVIDIA Blackwell SM100 architecture.
The basic computation logic of dgrad convolution kernel is, take 3D convolution as an example:
Xformed Actication (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
Xformed Activation (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
where in terms of GEMM perspective,
Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation

View File

@ -36,7 +36,7 @@
APIs on NVIDIA Blackwell SM100 architecture.
The basic computation logic of fprop convolution kernel is, take 3D convolution as an example:
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK)
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Activation (NZPQK)
where in terms of GEMM perspective,
Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation
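
For readers mapping this onto implicit-GEMM extents, the fprop case above corresponds to the standard identification (my summary, not quoted from the file); dgrad and wgrad permute these roles analogously:

$$ M_{\text{gemm}} = N \cdot Z \cdot P \cdot Q, \qquad N_{\text{gemm}} = K, \qquad K_{\text{gemm}} = T \cdot R \cdot S \cdot C. $$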

View File

@ -36,7 +36,7 @@
APIs on NVIDIA Blackwell SM100 architecture.
The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example:
Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
Xformed Activation (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
where in terms of GEMM perspective,
Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter

View File

@ -505,8 +505,12 @@ struct FwdRunner {
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
auto [Q, K, D, HB] = problem_shape;
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB);
fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{});
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {

View File

@ -32,7 +32,7 @@
\brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3.
This example showcases the use of CUTLASS to build backward fused
multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting
multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting
the NVIDIA Blackwell architecture.
Background and motivation
@ -117,6 +117,7 @@ struct Options {
std::vector<int> varlen_q;
std::vector<int> varlen_k;
int d = 128;
int d_vo = 128;
int iterations = 3;
bool verify = false;
bool verbose = false;
@ -178,6 +179,7 @@ struct Options {
}
cmd.get_cmd_line_argument("d", d, defaults.d);
cmd.get_cmd_line_argument("d_vo", d_vo, d);
cmd.get_cmd_line_argument("h", h, -1);
if (h == -1) h = 2048 / d;
@ -301,6 +303,7 @@ struct Options {
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
<< " --d=<int> Sets the D extent\n"
<< " --d_vo=<int> Sets the D_VO extent\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
@ -387,6 +390,7 @@ struct ExampleResult {
template<
bool kIsVarlen,
bool kIsMla,
class TileShape,
class DispatchPolicy,
class ActiveMask,
@ -404,8 +408,8 @@ struct BwdRunner {
// Q K D (H B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, cute::tuple<int, int>>
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, int, cute::tuple<int, int>>
>;
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
@ -461,45 +465,45 @@ struct BwdRunner {
// Methods
//
bool verify(const ProblemShape& problem_shape) {
auto [Q, K, D, HB] = problem_shape;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [H, B] = HB;
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
select<0,2,4>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
select<1,2,4>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
select<1,3,4>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,2,3>(problem_shape),
select<0,3,4>(problem_shape),
stride_O);
// keep going here! (this might be better in cursor)
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,3>(problem_shape),
select<0,4>(problem_shape),
stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
select<0,2,3>(problem_shape),
select<0,2,4>(problem_shape),
stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
select<1,2,3>(problem_shape),
select<1,2,4>(problem_shape),
stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
select<1,2,3>(problem_shape),
select<1,3,4>(problem_shape),
stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
select<0,2,3>(problem_shape),
select<0,3,4>(problem_shape),
stride_dO);
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
@ -595,14 +599,14 @@ struct BwdRunner {
ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, {options.h, options.b}
options.d, options.d_vo, {options.h, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1));
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1));
return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}};
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}
@ -610,24 +614,25 @@ struct BwdRunner {
/// Initialize operands to be used in the GEMM and reference GEMM
ProblemShape initialize(Options const& options) {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, HB] = tensor_shape;
auto [Q, K, D, D_VO, HB] = tensor_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B
auto shape_QO = make_shape(Q, D, make_shape(H, B));
auto shape_KV = make_shape(K, D, make_shape(H, B));
auto shape_Q = make_shape(Q, D, make_shape(H, B));
auto shape_O = make_shape(Q, D_VO, make_shape(H, B));
auto shape_K = make_shape(K, D, make_shape(H, B));
auto shape_V = make_shape(K, D_VO, make_shape(H, B));
auto shape_LSE = make_shape(Q, make_shape(H, B));
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H));
stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H));
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
stride_V = stride_K;
stride_O = stride_Q;
stride_dQ = stride_Q;
stride_dK = stride_K;
stride_dV = stride_V;
@ -637,20 +642,20 @@ struct BwdRunner {
return size(make_shape(1ull, shape));
};
block_Q.reset(lsize(shape_QO));
block_K.reset(lsize(shape_KV));
block_V.reset(lsize(shape_KV));
block_O.reset(lsize(shape_QO));
block_Q.reset(lsize(shape_Q));
block_K.reset(lsize(shape_K));
block_V.reset(lsize(shape_V));
block_O.reset(lsize(shape_O));
block_LSE.reset(lsize(shape_LSE));
block_dQ.reset(lsize(shape_QO));
block_dK.reset(lsize(shape_KV));
block_dV.reset(lsize(shape_KV));
block_dO.reset(lsize(shape_QO));
block_dQ.reset(lsize(shape_Q));
block_dK.reset(lsize(shape_K));
block_dV.reset(lsize(shape_V));
block_dO.reset(lsize(shape_O));
block_ref_dQ.reset(lsize(shape_QO));
block_ref_dK.reset(lsize(shape_KV));
block_ref_dV.reset(lsize(shape_KV));
block_ref_dQ.reset(lsize(shape_Q));
block_ref_dK.reset(lsize(shape_K));
block_ref_dV.reset(lsize(shape_V));
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
@ -665,23 +670,23 @@ struct BwdRunner {
initialize_block(block_ref_dV, seed + 2035);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
select<0,2,4>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
select<1,2,4>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
select<1,3,4>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,2,3>(problem_shape),
select<0,3,4>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,3>(problem_shape),
select<0,4>(problem_shape),
stride_LSE);
if (! options.skip_reference) {
@ -698,7 +703,7 @@ struct BwdRunner {
ExampleResult example_result;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, ActiveMask>;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, kIsMla, ActiveMask>;
typename Operation::Arguments arguments{
problem_shape,
@ -811,12 +816,12 @@ struct BwdRunner {
runtime_ms /= static_cast<float>(options.iterations);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= static_cast<double>(get<2>(problem_shape));
flops *= static_cast<double>(get<3,0>(problem_shape));
flops *= static_cast<double>(get<3,1>(problem_shape));
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
flops *= static_cast<double>(get<4,0>(problem_shape));
flops *= static_cast<double>(get<4,1>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
example_result.runtime_ms = runtime_ms;
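
One plausible reading of the updated FLOP count, consistent with the new formula above: the FMHA backward pass performs five matmul-shaped contractions over the Q x K score matrix (S = Q Kᵀ, dQ = dS K, and dK = dSᵀ Q over head dimension D; dP = dO Vᵀ and dV = Pᵀ dO over head dimension D_VO), each costing 2·Q·K·d flops per head, giving

$$ \text{FLOPs} \approx 2\, Q\, K\, (3D + 2D_{VO})\, H\, B, $$

halved under the causal mask since only about half of the score tiles are computed.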
@ -892,7 +897,7 @@ template<class Mask>
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
BwdRunner<decltype(is_varlen)::value, false,decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
@ -900,7 +905,7 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _64;
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -909,7 +914,7 @@ template<class Mask>
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
BwdRunner<decltype(is_varlen)::value, false, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
@ -917,7 +922,22 @@ void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
using HeadDim = _128;
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
}
template<class Mask>
void run_bwd_mla_192(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, true, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
};
using HeadDim = _192;
run(Shape<_64, _128, HeadDim, _128>{}, KernelCoop{}, "tma");
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -981,7 +1001,7 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;
@ -998,12 +1018,15 @@ int main_single(int argc, char const **args) {
};
with_causal([&](auto fusion) {
if (options.d <= 64) {
if (options.d <= 64 && options.d_vo == options.d) {
run_bwd_64(fusion, options, hw_info);
}
else if (options.d <= 128) {
else if (options.d <= 128 && options.d_vo == options.d) {
run_bwd_128(fusion, options, hw_info);
}
else if (options.d == 192 && options.d_vo == 128) {
run_bwd_mla_192(fusion, options, hw_info);
}
else {
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
}

View File

@ -485,7 +485,11 @@ struct MlaFwdRunner {
select<0,3>(problem_shape),
stride_LSE);
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
auto [Q, K, D, HB] = problem_shape;
auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB);
fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{});
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {

View File

@ -84,6 +84,8 @@ set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
@ -174,6 +176,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
TEST_BWD_MLA_BASIC
TEST_BWD_MLA_VARLEN
)
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})

View File

@ -37,13 +37,19 @@ There are three kernels to compute backwards:
`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.
## MLA Blackwell Backward
The sample also provides the feature of MLA backward(d=192, d_vo=128). To enable MLA backward, please specify `--d=192 --d_vo=128` when running the bwd sample.
`Sm100FmhaBwdMlaKernelTmaWarpSpecialized`is the main point for MLA backward. The MLA approach is slightly different from the original one to enable high performance with the MLA shape.
# MLA Inference for Blackwell
This sample provides code for fused multi-head latent attention inference in
the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
It supports fp16, bf16, and fp8 input and output types.
To accomodate the large output accumulator due to the large latent head dimension,
To accommodate the large output accumulator due to the large latent head dimension,
the sample demonstrates how to leverage 2Sm Blackwell tensor cores.
Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`

View File

@ -39,6 +39,7 @@
#include "../device/fmha.hpp"
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
@ -55,13 +56,14 @@ template<
class Element,
class ElementAccumulator,
class TileShape,
bool IsMla,
class Mask
>
class Sm100FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
// Q K D HB
// Q K D D_VO HB
ProblemShape problem_shape;
const Element* ptr_Q;
@ -98,11 +100,20 @@ public:
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
>;
using Operation = cutlass::fmha::device::FMHA<
using OperationNormal= cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
>;
using OperationMla = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
>;
using Operation = std::conditional_t<IsMla, OperationMla, OperationNormal>;
using Kernel = typename Operation::Kernel;
struct Params {
@ -121,7 +132,7 @@ private:
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
@ -141,7 +152,7 @@ private:
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
@ -163,6 +174,7 @@ private:
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
{ args.ptr_Q, args.stride_Q,
@ -207,7 +219,7 @@ public:
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
@ -227,7 +239,7 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
@ -256,7 +268,7 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, HB] = args.problem_shape;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment

View File

@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert {
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
return grid;
}
@ -103,7 +103,7 @@ struct FmhaKernelBwdConvert {
}
template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) {
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
@ -120,7 +120,7 @@ struct FmhaKernelBwdConvert {
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) {
ElementAcc value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
@ -139,13 +139,13 @@ struct FmhaKernelBwdConvert {
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape));
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape));
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape));
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape));
}
}
};

View File

@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO {
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape));
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape));
return grid;
}
@ -131,7 +131,7 @@ struct FmhaKernelBwdSumOdO {
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];

View File

@ -344,12 +344,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static bool can_implement(Arguments const& args) {
auto [Q, K, D, HB] = args.problem_shape;
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0) {
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
return false;
}
if (D % Alignment != 0) {
if (D % Alignment != 0 || D_VO % Alignment != 0) {
return false;
}
return true;
@ -362,7 +362,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static Params to_underlying_arguments(Arguments const& args, void*) {
auto [Q_, K_, D, HB] = args.problem_shape;
auto [Q_, K_, D, D_VO, HB] = args.problem_shape;
int Q = Q_;
int K = K_;
@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}, /*workspace=*/nullptr);
auto params_vdo = CollectiveMmaVDO::to_underlying_arguments(
make_shape(K, Q, D, HB),
make_shape(K, Q, D_VO, HB),
typename CollectiveMmaVDO::Arguments {
args.mainloop.ptr_v, args.mainloop.stride_v,
args.mainloop.ptr_do, args.mainloop.stride_do,
@ -446,21 +446,21 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
auto [Q, K, D, HB] = problem_shape;
auto [Q, K, D, D_VO, HB] = problem_shape;
using X = Underscore;
uint16_t mcast_mask = 0;
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));
auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in);
auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in);
auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);
auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);
auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);
auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
// set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
@ -681,7 +681,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, HB] = problem_shape;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
@ -974,11 +974,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args) {
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
@ -988,7 +988,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
@ -1003,7 +1003,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
}
for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {
if (elem_less(cDV(i), select<1,2>(problem_shape))) {
if (elem_less(cDV(i), select<1,3>(problem_shape))) {
gDV(i) = Element(0);
}
}
@ -1020,8 +1020,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
@ -1029,7 +1029,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDKtDK.data() = TmemAllocation::kDK;
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
@ -1065,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDVtDV.data() = TmemAllocation::kDV;
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
(_, _, blk_coord_k, _0{}, blk_coord_batch);
@ -1088,7 +1088,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);
// store tDVgDV
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,2>(problem_shape));
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
@ -1140,7 +1140,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, HB] = problem_shape;
auto [Q, K, D, D_VO, HB] = problem_shape;
// in tmem, S & P overlap
// and dP and dQ overlap
@ -1396,9 +1396,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using X = Underscore;
auto [Q, K, D, HB] = problem_shape;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
// must match TileShapeDQ
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
@ -1676,7 +1676,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
@ -1809,7 +1809,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, HB] = params.problem_shape;
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
return grid;

View File

@ -33,7 +33,9 @@
#pragma once
#include "cute/tensor.hpp"
#include "collective/fmha_fusion.hpp"
using namespace cutlass::fmha::collective;
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
@ -61,20 +63,20 @@ void __global__ fmha_bwd_reference_dQ_kernel(
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
make_coord(_0{}, _0{}, _0{}, _0{},idx2crd(idx_L, get<4>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in);
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
@ -82,10 +84,15 @@ void __global__ fmha_bwd_reference_dQ_kernel(
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
} // for idx_D0
for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
@ -135,20 +142,20 @@ void __global__ fmha_bwd_reference_dK_kernel(
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,3>(offset), mDK_in);
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,4>(offset), mDK_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
@ -156,10 +163,14 @@ void __global__ fmha_bwd_reference_dK_kernel(
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
} // for idx_D0
for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
@ -209,20 +220,20 @@ void __global__ fmha_bwd_reference_dV_kernel(
ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in)));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
auto mDV = domain_offset(select<1,2,3>(offset), mDV_in);
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDV = domain_offset(select<1,3,4>(offset), mDV_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
@ -244,7 +255,7 @@ void __global__ fmha_bwd_reference_dV_kernel(
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
ElementAcc rS = static_cast<Element>(mS[idx_Q]);

View File

@ -62,19 +62,20 @@ void __global__ fmha_reference_kernel(
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mQ)));
auto id = make_identity_tensor(make_shape(1, 1));
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) {
auto coord_L = idx2crd(idx_L, shape<3>(problem_shape_in));
auto coord_L = idx2crd(idx_L, shape<4>(problem_shape_in));
auto get_coord_in = [&]() {
if constexpr (rank_v<decltype(get<2>(ProblemShapeIn{}))> == 2) {
return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), coord_L);
return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), cute::make_tuple(_0{}, _0{}), coord_L);
} else {
return cute::make_tuple(idx_Q, _0{}, _0{}, coord_L);
return cute::make_tuple(idx_Q, _0{}, _0{}, _0{}, coord_L);
}
};
auto coord_in = get_coord_in();
auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<3,1>(coord_in));
auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<4,1>(coord_in));
int head_qk = 0;
int head_v = 0;
@ -83,7 +84,7 @@ void __global__ fmha_reference_kernel(
head_qk = size<2, 0>(problem_shape) + size<2, 1>(problem_shape);
head_v = size<2, 0>(problem_shape);
} else {
head_qk = size<2>(problem_shape);
head_qk = size<3>(problem_shape);
head_v = head_qk;
}
@ -157,6 +158,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}

View File

@ -835,7 +835,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
}
}
else {
std::cout << " Verfication is turned off for this run." << std::endl;
std::cout << " Verification is turned off for this run." << std::endl;
}
// Run profiling loop

View File

@ -259,7 +259,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
// Step 2: The Mainloop.
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
@ -394,7 +394,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yeilds the CTA-local work.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;

View File

@ -295,7 +295,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
// Step 2: The Mainloop.
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
@ -433,7 +433,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yeilds the CTA-local work.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;

View File

@ -333,7 +333,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
// Step 2: The Mainloop.
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
@ -471,7 +471,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yeilds the CTA-local work.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;

View File

@ -328,7 +328,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
// Step 2: The Mainloop.
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
@ -473,7 +473,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yeilds the CTA-local work.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;

View File

@ -341,7 +341,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
// Step 2: The Mainloop.
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
@ -527,7 +527,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yeilds the CTA-local work.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;

View File

@ -200,7 +200,7 @@ int main(int argc, char** argv)
// Construct tiled copy, a tiling of copy atoms.
//
// Note, this assumes the vector and thread layouts are aligned with contigous data
// Note, this assumes the vector and thread layouts are aligned with contiguous data
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
// reads. Alternative value layouts are also possible, though incompatible layouts
// will result in compile time errors.

View File

@ -90,18 +90,17 @@ If you already know the TV layout you want to use for your tiled copy, CuTe DSL
# Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN))
gA = cute.zipped_divide(mA, tiler_mn)
where `tiler_mn` is the tile size per thread block and `tv_layout` is the TV layout which maps
thread index and inter-thread index of data array per thread to logical coordinates of elements in
input and output tensors.
Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy` utility.
Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy_tv` utility, which
infers the tiler and tv layout for the tiled copy automatically, where `tiler` is the tile size per thread
block and `tv_layout` is the TV layout which maps thread index and inter-thread index of data array per
thread to logical coordinates of elements in input and output tensors.
.. code-block:: python
blkA = gA[((None, None), bidx)] # (TileM,TileN)
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
# get slice of tiled_copy_A for current thread
thr_copy_A = tiled_copy_A.get_slice(tidx)
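For intuition, here is a plain-Python sketch (no DSL required, shapes made up for this illustration) of what a thread layout plus value layout encode: with 8 threads tiled along M and each thread owning 4 contiguous values along N, the pair covers an 8x4 tile and fixes which coordinates each thread copies.

thr_shape = (8, 1)   # 8 threads along M, 1 along N (assumed for this sketch)
val_shape = (1, 4)   # each thread copies a 1x4 run of values

tile_m = thr_shape[0] * val_shape[0]   # 8
tile_n = thr_shape[1] * val_shape[1]   # 4

for tid in range(thr_shape[0] * thr_shape[1]):
    tm, tn = tid % thr_shape[0], tid // thr_shape[0]
    coords = [(tm * val_shape[0] + vm, tn * val_shape[1] + vn)
              for vm in range(val_shape[0]) for vn in range(val_shape[1])]
    print(f"thread {tid} -> {coords}")

cute.make_tiled_copy_tv performs the analogous (but more general) inference from thr_layout and val_layout, which is why the rewritten example no longer passes an explicit tv_layout and tiler_mn.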
@ -140,8 +139,8 @@ def elementwise_add_kernel(
gC: cute.Tensor,
cC: cute.Tensor, # coordinate tensor
shape: cute.Shape,
tv_layout: cute.Layout,
tiler_mn: cute.Shape,
thr_layout: cute.Layout,
val_layout: cute.Layout,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
@ -165,9 +164,9 @@ def elementwise_add_kernel(
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
tiled_copy_B = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
tiled_copy_C = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)
thr_copy_A = tiled_copy_A.get_slice(tidx)
thr_copy_B = tiled_copy_B.get_slice(tidx)
@ -254,7 +253,7 @@ def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128):
cC = cute.zipped_divide(idC, tiler=tiler_mn)
print(f"[DSL INFO] coord tensor = {cC.type}")
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, tv_layout, tiler_mn).launch(
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(
grid=[cute.size(gC, mode=[1]), 1, 1],
block=[cute.size(tv_layout, mode=[0]), 1, 1],
)
@ -362,7 +361,7 @@ def run_elementwise_add(
workspace_generator=generate_tensors,
workspace_count=10,
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
iterations=iterations,
)
# Print execution results

View File

@ -353,7 +353,7 @@ def run_elementwise_apply_and_verify(
current_stream,
),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
iterations=iterations,
use_cuda_graphs=True,
stream=current_stream,
)

View File

@ -32,13 +32,13 @@ from typing import Type, Union, Callable
import torch
import cuda.bindings.driver as cuda
import cutlass.cute.testing as testing
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, warp
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
import cutlass.utils.ampere_helpers as sm80_utils
import cutlass.utils as utils
"""
A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL.
@ -163,7 +163,7 @@ class FlashAttentionForwardAmpere:
# Check if block size setting is out of shared memory capacity
# Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size
smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2
smem_capacity = sm80_utils.SMEM_CAPACITY["sm80"]
smem_capacity = utils.get_smem_capacity_in_bytes("sm_80")
if smem_usage > smem_capacity:
return False
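As a worked example of the check above for one plausible configuration (fp16 inputs, head_dim=128, m_block_size=128, n_block_size=64; the ~163 KB figure is an assumed per-CTA SM80 budget, not a value read from the utility):

bytes_per_elem = 2                             # fp16
m_block_size, n_block_size, head_dim = 128, 64, 128

smem_usage = (m_block_size * head_dim          # Q tile
              + n_block_size * head_dim * 2    # K tile + V tile
              ) * bytes_per_elem
smem_capacity = 163 * 1024                     # assumed SM80 per-CTA budget

print(smem_usage, smem_capacity, smem_usage <= smem_capacity)   # 65536 166912 True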
@ -469,21 +469,9 @@ class FlashAttentionForwardAmpere:
warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
self._dtype,
)
smem_tiled_copy_Q = cute.make_tiled_copy(
smem_copy_atom_Q,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
smem_tiled_copy_K = cute.make_tiled_copy(
smem_copy_atom_K,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
smem_tiled_copy_V = cute.make_tiled_copy(
smem_copy_atom_V,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma)
smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma)
smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma)
smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx)
smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx)
@ -702,11 +690,7 @@ class FlashAttentionForwardAmpere:
cute.nvgpu.CopyUniversalOp(), self._dtype
)
# tiled copy atom for O
smem_tiled_copy_O = cute.make_tiled_copy(
smem_copy_atom_O,
layout_tv=tiled_mma.tv_layout_C_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)),
)
smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma)
smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx)
taccOrO = smem_thr_copy_O.retile(rO)
taccOsO = smem_thr_copy_O.partition_D(sO)
@ -1178,7 +1162,7 @@ class FlashAttentionForwardAmpere:
return cute.arch.exp2(x)
def run_flash_attention_fwd(
def run(
dtype: Type[cutlass.Numeric],
batch_size: int,
seqlen_q: int,
@ -1193,6 +1177,8 @@ def run_flash_attention_fwd(
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
# Skip unsupported testcase
if not FlashAttentionForwardAmpere.can_implement(
@ -1207,6 +1193,23 @@ def run_flash_attention_fwd(
f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}"
)
print(f"Running Ampere SM80 FlashAttentionForward test with:")
print(f" dtype: {dtype}")
print(f" batch_size: {batch_size}")
print(f" seqlen_q: {seqlen_q}")
print(f" seqlen_k: {seqlen_k}")
print(f" num_head: {num_head}")
print(f" head_dim: {head_dim}")
print(f" softmax_scale: {softmax_scale}")
print(f" m_block_size: {m_block_size}")
print(f" n_block_size: {n_block_size}")
print(f" num_threads: {num_threads}")
print(f" is_causal: {is_causal}")
print(f" warmup_iterations: {warmup_iterations}")
print(f" iterations: {iterations}")
print(f" skip_ref_check: {skip_ref_check}")
print(f" use_cold_l2: {use_cold_l2}")
# Create tensor Q/K/V/O
def create_tensor(
batch_size: int,
@ -1217,22 +1220,28 @@ def run_flash_attention_fwd(
) -> cute.Tensor:
# (batch_size, seqlen, num_head, head_dim)
shape = (batch_size, seqlen, num_head, head_dim)
return (
torch.empty(*shape, dtype=torch.int32).random_(-2, 2).to(dtype=dtype).cuda()
torch_tensor = (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
.to(dtype=cutlass_torch.dtype(dtype))
.cuda()
)
# assume input is 16B aligned.
cute_tensor = (
from_dlpack(torch_tensor, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3,
stride_order=torch_tensor.dim_order(),
divisibility=(128 // dtype.width),
)
)
return cute_tensor, torch_tensor
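The divisibility=(128 // dtype.width) argument encodes the 16-byte (128-bit) alignment assumption stated in the comment: the dynamic contiguous mode (head_dim here) must hold a whole number of 128-bit vectors. A small sketch of the resulting constraint for a few element widths (widths in bits, purely illustrative):

for name, width_bits in [("Float16", 16), ("BFloat16", 16), ("Float32", 32), ("Int8", 8)]:
    elems_per_128b = 128 // width_bits
    print(f"{name:8s}: contiguous mode must be a multiple of {elems_per_128b} elements")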
q = create_tensor(
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
)
k = create_tensor(
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
)
v = create_tensor(
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
)
o = create_tensor(
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
)
q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
fa2_fwd = FlashAttentionForwardAmpere(
head_dim,
@ -1241,78 +1250,63 @@ def run_flash_attention_fwd(
num_threads,
is_causal,
)
# assume input is 16B align.
q_tensor = (
from_dlpack(q, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=q.dim_order(), divisibility=(128 // dtype.width)
)
)
k_tensor = (
from_dlpack(k, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=k.dim_order(), divisibility=(128 // dtype.width)
)
)
v_tensor = (
from_dlpack(v, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=v.dim_order(), divisibility=(128 // dtype.width)
)
)
o_tensor = (
from_dlpack(o, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=o.dim_order(), divisibility=(128 // dtype.width)
)
)
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# compile the fa2 forward pass
compiled_fa2_fwd = cute.compile(
fa2_fwd, q_tensor, k_tensor, v_tensor, o_tensor, softmax_scale, current_stream
compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream)
if not skip_ref_check:
compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream)
torch.cuda.synchronize()
q_ref = q_torch.permute(0, 2, 1, 3)
k_ref = k_torch.permute(0, 2, 1, 3)
v_ref = v_torch.permute(0, 2, 1, 3)
torch.backends.cuda.enable_flash_sdp(enabled=True)
ref_o = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
).permute(0, 2, 1, 3)
torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
print("Results verified successfully!")
def generate_tensors():
q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
return testing.JitArguments(
q_workspace,
k_workspace,
v_workspace,
o_workspace,
softmax_scale,
current_stream,
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
q_torch.numel() * q_torch.element_size()
+ k_torch.numel() * k_torch.element_size()
+ v_torch.numel() * v_torch.element_size()
+ o_torch.numel() * o_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
avg_time_us = testing.benchmark(
compiled_fa2_fwd,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
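The use_cold_l2 path above rotates through several independent input/output sets so that no timed launch re-reads data that may still sit in L2. The following is only a conceptual sketch of how such a workspace count could be chosen; it is not the implementation of testing.get_workspace_count, and the 50 MB L2 size is an assumption of this sketch.

import math

def workspace_count_estimate(one_workspace_bytes: int,
                             warmup_iterations: int,
                             iterations: int,
                             l2_bytes: int = 50 * 1024 * 1024) -> int:
    # Enough distinct workspaces to cover L2, capped at the number of launches.
    needed = math.ceil(l2_bytes / max(one_workspace_bytes, 1))
    return max(1, min(needed, warmup_iterations + iterations))

print(workspace_count_estimate(8 * 1024 * 1024, warmup_iterations=10, iterations=100))  # 7

The benchmark call above then cycles through the generated workspaces across iterations, which is what lets each timed launch see cold data.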
# warmup
for _ in range(warmup_iterations):
compiled_fa2_fwd(
q_tensor,
k_tensor,
v_tensor,
o_tensor,
softmax_scale,
current_stream,
)
# run the compiled fa2 forward pass
for _ in range(iterations):
compiled_fa2_fwd(
q_tensor,
k_tensor,
v_tensor,
o_tensor,
softmax_scale,
current_stream,
)
torch.cuda.synchronize()
if skip_ref_check:
return
# reference implementation
q_ref = q.permute(0, 2, 1, 3)
k_ref = k.permute(0, 2, 1, 3)
v_ref = v.permute(0, 2, 1, 3)
torch.backends.cuda.enable_flash_sdp(enabled=True)
ref_o = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
).permute(0, 2, 1, 3)
torch.testing.assert_close(o.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@ -1334,9 +1328,15 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference check"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
run_flash_attention_fwd(
run(
args.dtype,
args.batch_size,
args.seqlen_q,
@ -1348,6 +1348,10 @@ if __name__ == "__main__":
args.n_block_size,
args.num_threads,
args.is_causal,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -634,16 +634,50 @@ class SGemm:
return
def main(
def run(
mnk: Tuple[int, int, int],
a_major: str,
b_major: str,
c_major: str,
problem_shape: Tuple[int, int, int],
static_shape: bool = False,
warmup_iterations: int = 2,
iterations: int = 100,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
M, N, K = problem_shape
"""Execute SIMT GEMM operation and benchmark performance.
:param mnk: GEMM problem size (M, N, K, L)
:type mnk: Tuple[int, int, int, int]
:param a_major: Memory layout of tensor A
:type a_major: str
:param b_major: Memory layout of tensor B
:type b_major: str
:param c_major: Memory layout of tensor C
:type c_major: str
:param static_shape: Whether to use static shape optimization, defaults to False
:type static_shape: bool, optional
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 2
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 100
:type iterations: int, optional
:param skip_ref_check: Skip validation against reference implementation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Ampere SIMT GEMM example:")
print(f"mnk: {mnk}")
print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}")
print(f"Static shape: {static_shape}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
M, N, K = mnk
# Create and permute tensor A/B/C
def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
@ -710,20 +744,6 @@ def main(
print("Executing GEMM kernel...")
avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(
a_tensor, b_tensor, c_tensor, current_stream
),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
use_cuda_graphs=False,
stream=current_stream,
)
# Print execution results
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
torch.cuda.synchronize()
@ -732,6 +752,71 @@ def main(
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
# Create new tensors for each workspace to ensure cold L2 cache
a_workspace = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
b_workspace = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
c_workspace = create_and_permute_tensor(M, N, c_major == "m", torch.float32)
if static_shape:
a_tensor_workspace = (
from_dlpack(a_workspace, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if a_major == "k" else 0),
divisibility=divisibility_a,
)
)
else:
a_tensor_workspace = from_dlpack(a_workspace, assumed_align=16)
b_tensor_workspace = (
from_dlpack(b_workspace, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if b_major == "k" else 0),
divisibility=divisibility_b,
)
)
c_tensor_workspace = (
from_dlpack(c_workspace, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
.mark_compact_shape_dynamic(
mode=(1 if c_major == "n" else 0),
divisibility=divisibility_c,
)
)
return testing.JitArguments(
a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, current_stream
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a.numel() * a.element_size()
+ b.numel() * b.element_size()
+ c.numel() * c.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
avg_time_us = testing.benchmark(
gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
# Print execution results
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
@ -753,19 +838,27 @@ if __name__ == "__main__":
parser.add_argument("--warmup_iterations", default=2, type=int)
parser.add_argument("--iterations", default=100, type=int)
parser.add_argument("--skip_ref_check", action="store_true")
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
print("Running SIMT GEMM example:")
torch.manual_seed(1024)
main(
run(
args.mnk,
args.a_major,
args.b_major,
args.c_major,
args.mnk,
args.static_shape,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -51,7 +51,7 @@ This GEMM kernel supports the following features:
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
- Threadblock rasterization to improve data re-use
- Supports multi-stage pipeline to overlap computation and memory access
- Implements shared memory buffering for epilogue to increase coalesed global memory access
- Implements shared memory buffering for epilogue to increase coalesced global memory access
This GEMM works as follows:
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
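The multi-stage pipeline mentioned above can be pictured as a small circular buffer of SMEM stages: while one stage is consumed by the MMA, later K-tiles are already being prefetched into the other stages. A toy schedule, in pure Python with made-up stage and tile counts, just to illustrate the overlap:

num_stages, num_k_tiles = 3, 8

# Prefetch the first (num_stages - 1) tiles before the main loop.
schedule = [f"prefetch k{k} -> stage {k % num_stages}" for k in range(num_stages - 1)]

for k in range(num_k_tiles):
    next_k = k + num_stages - 1
    if next_k < num_k_tiles:
        # Issue the async copy for a future tile while computing the current one.
        schedule.append(f"prefetch k{next_k} -> stage {next_k % num_stages}")
    schedule.append(f"compute  k{k} from stage {k % num_stages}")

print("\n".join(schedule))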
@ -214,7 +214,7 @@ class TensorOpGemm:
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
)
# Creates a synchonous copy atom and thread layouts for the epilogue
# Creates a synchronous copy atom and thread layouts for the epilogue
c_copy_bits = 128
atom_sync_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
@ -550,16 +550,8 @@ class TensorOpGemm:
# Creates the tiled copy so that it matches the thread-value layout
# expected by the tiled mma
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
@ -836,8 +828,7 @@ class TensorOpGemm:
if major_mode == utils.LayoutEnum.ROW_MAJOR
else cute.make_layout((copy_elems, 1))
)
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
def raster_tile(self, i, j, f):
new_i = i // f
@ -845,20 +836,33 @@ class TensorOpGemm:
return (new_i, new_j)
def run_tensor_op_gemm(
def run(
a_major: str,
b_major: str,
c_major: str,
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
problem_shape: Tuple[int, int, int, int],
mnkl: Tuple[int, int, int, int],
atom_layout_mnk: Tuple[int, int, int],
warmup_iterations: int = 2,
iterations: int = 100,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
M, N, K, L = problem_shape
print(f"Running Ampere tensor core GEMM example:")
print(f"mnkl: {mnkl}")
print(
f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
)
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Atoms layout: {atom_layout_mnk}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
M, N, K, L = mnkl
# Create and permute tensor A/B/C
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
@ -866,23 +870,28 @@ def run_tensor_op_gemm(
# else: (l, mode0, mode1) -> (mode0, mode1, l)
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
return (
torch_tensor = (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
.to(dtype=dtype)
.to(dtype=cutlass_torch.dtype(dtype))
.permute(permute_order)
.cuda()
)
# assume input is 16B aligned
cute_tensor = (
from_dlpack(torch_tensor, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0))
.mark_compact_shape_dynamic(
mode=(1 if not is_mode0_major else 0),
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
divisibility=(128 // dtype.width),
)
)
return cute_tensor, torch_tensor
a = create_and_permute_tensor(
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
)
b = create_and_permute_tensor(
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
)
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
tensor_op_gemm = TensorOpGemm(
ab_dtype,
@ -891,56 +900,49 @@ def run_tensor_op_gemm(
atom_layout_mnk,
)
# assume input is 16B aligned
a_tensor = (
from_dlpack(a, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if a_major == "k" else 0),
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
divisibility=(128 // ab_dtype.width),
)
)
b_tensor = (
from_dlpack(b, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if b_major == "k" else 0),
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
divisibility=(128 // ab_dtype.width),
)
)
c_tensor = (
from_dlpack(c, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
.mark_compact_shape_dynamic(
mode=(1 if c_major == "n" else 0),
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
divisibility=(128 // c_dtype.width),
)
)
print("Compiling kernel with cute.compile ...")
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC)
print("Executing GEMM kernel...")
if not skip_ref_check:
ref = torch.einsum(
"mkl,nkl->mnl",
a_torch.to(dtype=torch.float32),
b_torch.to(dtype=torch.float32),
).to(cutlass_torch.dtype(c_dtype))
compiled_gemm(mA, mB, mC)
print("Verifying results...")
torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
return testing.JitArguments(a_workspace, b_workspace, c_workspace)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch.numel() * a_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
iterations=iterations,
use_cuda_graphs=False,
)
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
print("Verifying results...")
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
@ -985,10 +987,15 @@ if __name__ == "__main__":
parser.add_argument("--warmup_iterations", default=2, type=int)
parser.add_argument("--iterations", default=100, type=int)
parser.add_argument("--skip_ref_check", action="store_true")
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
print("Running Ampere tensor core GEMM example:")
run_tensor_op_gemm(
run(
args.a_major,
args.b_major,
args.c_major,
@ -1000,5 +1007,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

File diff suppressed because it is too large.


View File

@ -212,7 +212,7 @@ class DenseGemmKernel:
self.occupancy = 1
self.threads_per_cta = 128
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -1106,11 +1106,7 @@ class DenseGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
@ -1772,7 +1768,7 @@ def run_dense_gemm(
ref_c = ref
elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
# m major: (l, n, m) -> (m, n, l)
# k major: (l, m, n) -> (m, n, l)
# n major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(

View File

@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.cute.testing as testing
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
@ -226,7 +227,7 @@ class PersistentDenseGemmKernel:
self.cta_sync_bar_id = 0
self.epilog_sync_bar_id = 1
self.tmem_ptr_sync_bar_id = 2
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel:
return can_implement
def run_dense_gemm(
def run(
mnkl: Tuple[int, int, int, int],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
@ -1832,17 +1829,58 @@ def run_dense_gemm(
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
use_2cta_instrs: bool,
use_tma_store: bool,
tolerance: float,
mma_tiler_mn: Tuple[int, int] = (256, 256),
cluster_shape_mn: Tuple[int, int] = (2, 1),
use_2cta_instrs: bool = True,
use_tma_store: bool = True,
tolerance: float = 1e-01,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
"""
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
"""Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
This function prepares input tensors, configures and launches the persistent GEMM kernel,
optionally performs reference validation, and benchmarks the execution performance.
:param mnkl: Problem size (M, N, K, L)
:type mnkl: Tuple[int, int, int, int]
:param ab_dtype: Data type for input tensors A and B
:type ab_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Data type for accumulation during matrix multiplication
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the
default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters.
:type mma_tiler_mn: Tuple[int, int], optional
:param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
:type cluster_shape_mn: Tuple[int, int], optional
:param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
:type use_2cta_instrs: bool, optional
:param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use
the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
:type use_tma_store: bool, optional
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
:type tolerance: float, optional
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:raises RuntimeError: If CUDA GPU is not available
:raises ValueError: If the configuration is invalid or unsupported by the kernel
:return: Execution time of the GEMM kernel
:rtype: float
"""
print(f"Running Blackwell Persistent Dense GEMM test with:")
print(f"mnkl: {mnkl}")
@ -1855,6 +1893,7 @@ def run_dense_gemm(
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
# Unpack parameters
m, n, k, l = mnkl
@ -1931,15 +1970,15 @@ def run_dense_gemm(
is_dynamic_layout=is_dynamic_layout,
)
return f32_torch_tensor, cute_tensor, torch_tensor
return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
a_ref, a_tensor, a_torch = create_and_permute_tensor(
a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
)
b_ref, b_tensor, b_torch = create_and_permute_tensor(
b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
)
c_ref, c_tensor, c_torch = create_and_permute_tensor(
c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
)
@ -1967,16 +2006,8 @@ def run_dense_gemm(
gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
)
# Launch GPU kernel
# Warm up
for i in range(warmup_iterations):
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
# Execution
for i in range(iterations):
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
# Compute reference result
if not skip_ref_check:
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
if ab_dtype in {
cutlass.Int8,
cutlass.Uint8,
@ -2028,6 +2059,40 @@ def run_dense_gemm(
rtol=1e-05,
)
def generate_tensors():
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
c_tensor, _ = cutlass_torch.cute_tensor_like(
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
)
return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch_cpu.numel() * a_torch_cpu.element_size()
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
@ -2090,6 +2155,12 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -2102,7 +2173,7 @@ if __name__ == "__main__":
if len(args.cluster_shape_mn) != 2:
parser.error("--cluster_shape_mn must contain exactly 2 values")
run_dense_gemm(
run(
args.mnkl,
args.ab_dtype,
args.c_dtype,
@ -2118,5 +2189,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -223,7 +223,7 @@ class DenseGemmKernel:
self.occupancy = 1
self.threads_per_cta = 128
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -1063,11 +1063,7 @@ class DenseGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)

View File

@ -43,6 +43,7 @@ import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import Int32, Int64, Float32, Boolean
@ -90,7 +91,7 @@ Constraints for this example:
* Number of heads in Q must be divisible by number of heads in K
* mma_tiler_mn must be 128,128
* Batch size must be the same for Q, K, and V tensors
* For causal masking, use --has_casual_mask (note: specify without =True/False)
* For causal masking, use --is_causal (note: specify without =True/False)
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
"""
@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward:
smem_copy_atom = sm100_utils.get_smem_store_op(
self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
)
tiled_smem_store = cute.make_tiled_copy(
smem_copy_atom,
layout_tv=tiled_tmem_load.layout_dst_tv_tiled,
tiler_mn=tiled_tmem_load.tiler_mn,
)
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward:
return tile_sched_params, grid
def run_fmha_and_verify(
def run(
q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
in_dtype: Type[cutlass.Numeric],
@ -2628,7 +2625,7 @@ def run_fmha_and_verify(
pv_acc_dtype: Type[cutlass.Numeric],
mma_tiler_mn: Tuple[int, int],
is_persistent: bool,
has_casual_mask: bool,
is_causal: bool,
scale_q: float,
scale_k: float,
scale_v: float,
@ -2638,6 +2635,8 @@ def run_fmha_and_verify(
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results.
@ -2670,8 +2669,8 @@ def run_fmha_and_verify(
:type mma_tiler_mn: Tuple[int, int]
:param is_persistent: Whether to use persistent kernel optimization
:type is_persistent: bool
:param has_casual_mask: Whether to apply causal masking
:type has_casual_mask: bool
:param is_causal: Whether to apply causal masking
:type is_causal: bool
:param scale_q: Scaling factor for query tensor
:type scale_q: float
:param scale_k: Scaling factor for key tensor
@ -2690,9 +2689,13 @@ def run_fmha_and_verify(
:type iterations: int
:param skip_ref_check: Skip validation against reference implementation
:type skip_ref_check: bool
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
:type use_cold_l2: bool
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
:raises RuntimeError: If GPU is unavailable for computation
:return: Execution time of the FMHA kernel in microseconds
:rtype: float
"""
print(f"Running Blackwell SM100 FMHA test with:")
@ -2704,13 +2707,17 @@ def run_fmha_and_verify(
print(f" pv_acc_dtype: {pv_acc_dtype}")
print(f" mma_tiler_mn: {mma_tiler_mn}")
print(f" is_persistent: {is_persistent}")
print(f" has_casual_mask: {has_casual_mask}")
print(f" is_causal: {is_causal}")
print(f" scale_q: {scale_q}")
print(f" scale_k: {scale_k}")
print(f" scale_v: {scale_v}")
print(f" inv_scale_o: {inv_scale_o}")
print(f" scale_softmax: {scale_softmax}")
print(f" tolerance: {tolerance}")
print(f" warmup_iterations: {warmup_iterations}")
print(f" iterations: {iterations}")
print(f" skip_ref_check: {skip_ref_check}")
print(f" use_cold_l2: {use_cold_l2}")
# Unpack parameters
b, s_q, h_q, d = q_shape
@ -2882,7 +2889,7 @@ def run_fmha_and_verify(
mma_tiler = (*mma_tiler_mn, d)
mask_type = MaskType.NO_MASK
if has_casual_mask:
if is_causal:
mask_type = MaskType.CAUSAL_MASK
else:
if isinstance(s_k, tuple):
@ -2942,41 +2949,7 @@ def run_fmha_and_verify(
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
# Warmup
for _ in range(warmup_iterations):
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
# Execute kernel
for _ in range(iterations):
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
torch.cuda.synchronize()
def run_torch_fmha(
q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False
):
def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False):
h_q = q.shape[2]
h_k = k.shape[2]
@ -3005,7 +2978,7 @@ def run_fmha_and_verify(
v = v.transpose(1, 2)
# For the situation that torch has not supported, we need to handle it manually
situation1 = has_casual_mask and (q.is_nested or k.is_nested)
situation1 = is_causal and (q.is_nested or k.is_nested)
situation2 = (q.is_nested and not k.is_nested) or (
not q.is_nested and k.is_nested
)
@ -3025,8 +2998,9 @@ def run_fmha_and_verify(
attn_mask=None,
dropout_p=0.0,
scale=scale_softmax,
is_causal=has_casual_mask,
is_causal=is_causal,
)
ref_i = ref_i.transpose(0, 1) * scale_output
ref_list.append(ref_i)
if q.is_nested:
ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged)
@ -3040,15 +3014,28 @@ def run_fmha_and_verify(
attn_mask=None,
dropout_p=0.0,
scale=scale_softmax,
is_causal=has_casual_mask,
is_causal=is_causal,
)
ref = ref.transpose(1, 2) * scale_output
ref = ref.transpose(1, 2) * scale_output
return ref
if not skip_ref_check:
# Execute kernel once for reference checking
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
print("Verifying results...")
o_ref = run_torch_fmha(
q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask
q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal
)
if o_ref.is_nested:
@ -3095,6 +3082,76 @@ def run_fmha_and_verify(
torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
_, q_tensor_workspace, _ = create_and_pad_tensor(
qo_shape,
qo_padding,
in_dtype,
s_cumsum=cum_seqlen_q_torch,
is_dynamic_layout=True,
)
_, k_tensor_workspace, _ = create_and_pad_tensor(
kv_shape,
kv_padding,
in_dtype,
s_cumsum=cum_seqlen_k_torch,
is_dynamic_layout=True,
)
_, v_tensor_workspace, _ = create_and_pad_tensor(
kv_shape,
kv_padding,
in_dtype,
s_cumsum=cum_seqlen_k_torch,
is_dynamic_layout=True,
)
_, o_tensor_workspace, _ = create_and_pad_tensor(
qo_shape,
qo_padding,
out_dtype,
s_cumsum=cum_seqlen_q_torch,
is_dynamic_layout=True,
)
return testing.JitArguments(
q_tensor_workspace.iterator,
k_tensor_workspace.iterator,
v_tensor_workspace.iterator,
o_tensor_workspace.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
workspace_count = 1
if use_cold_l2:
q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
one_workspace_bytes = (
q_torch_effective.numel() * q_torch_effective.element_size()
+ k_torch_effective.numel() * k_torch_effective.element_size()
+ v_torch_effective.numel() * v_torch_effective.element_size()
+ o_torch_effective.numel() * o_torch_effective.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_fmha,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
def parse_comma_separated_ints(s: str):
@ -3185,7 +3242,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--has_casual_mask",
"--is_causal",
action="store_true",
help="Whether to use casual mask",
)
@ -3263,6 +3320,13 @@ if __name__ == "__main__":
help="Skip reference check",
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
if len(args.q_shape) != 4:
@ -3279,7 +3343,7 @@ if __name__ == "__main__":
torch.manual_seed(1111)
run_fmha_and_verify(
run(
args.q_shape,
args.k_shape,
args.in_dtype,
@ -3288,7 +3352,7 @@ if __name__ == "__main__":
args.pv_acc_dtype,
args.mma_tiler_mn,
args.is_persistent,
args.has_casual_mask,
args.is_causal,
args.scale_q,
args.scale_k,
args.scale_v,
@ -3298,6 +3362,7 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -36,6 +36,7 @@ import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
@ -157,7 +158,7 @@ class GroupedGemmKernel:
self.tmem_ptr_sync_bar_id = 2
# Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion
self.tensormap_ab_init_bar_id = 4
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
self.num_tma_load_bytes = 0
def _setup_attributes(self):
@ -951,7 +952,7 @@ class GroupedGemmKernel:
# Specialized MMA warp
#
if warp_idx == self.mma_warp_id:
# initilize tensormap A, B for TMA warp
# initialize tensormap A, B for TMA warp
if cutlass.const_expr(self.delegate_tensormap_ab_init):
tensormap_manager.init_tensormap_from_atom(
tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id
@ -1540,11 +1541,7 @@ class GroupedGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
@ -1815,7 +1812,136 @@ class GroupedGemmKernel:
tensor_memory_management_bytes = 12
def run_grouped_gemm(
# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
torch_tensor_cpu: torch.Tensor = None,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
"""Create a GPU tensor from scratch or based on an existing CPU tensor.
:param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one.
:type torch_tensor_cpu: torch.Tensor, optional
"""
if torch_tensor_cpu is None:
# Create new CPU tensor
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
# Create GPU tensor from CPU tensor (new or existing)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)
def create_tensors_for_all_groups(
problem_sizes_mnkl: List[tuple[int, int, int, int]],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
torch_fp32_tensors_abc: List[List[torch.Tensor]] = None,
) -> tuple[
List[List[int]],
List[List[torch.Tensor]],
List[tuple],
List[List[tuple]],
List[List[torch.Tensor]],
]:
if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len(
problem_sizes_mnkl
):
raise ValueError("torch_fp32_tensors_abc must have one entry per group")
# Initialize lists to store tensors for all groups
new_torch_fp32_tensors_abc = (
[] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc
)
torch_tensors_abc = []
cute_tensors_abc = []
strides_abc = []
ptrs_abc = []
# Iterate through all groups and create tensors for each group
for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
# Get existing CPU tensors if available, otherwise None
existing_cpu_a = (
torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None
)
existing_cpu_b = (
torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None
)
existing_cpu_c = (
torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None
)
# Create tensors (reusing CPU tensors if provided)
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(
l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a
)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(
l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b
)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(
l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c
)
# Only append to new_torch_fp32_tensors_abc if we created new CPU tensors
if torch_fp32_tensors_abc is None:
new_torch_fp32_tensors_abc.append(
[tensor_fp32_a, tensor_fp32_b, tensor_fp32_c]
)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
cute_tensors_abc.append(
(
cute_tensor_a,
cute_tensor_b,
cute_tensor_c,
)
)
return (
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
new_torch_fp32_tensors_abc,
)
def run(
num_groups: int,
problem_sizes_mnkl: tuple[int, int, int, int],
ab_dtype: Type[cutlass.Numeric],
@ -1832,8 +1958,16 @@ def run_grouped_gemm(
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""Run grouped GEMM example with specified configurations."""
"""Run grouped GEMM example with specified configurations.
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Blackwell Grouped GEMM test with:")
print(f"{num_groups} groups")
for i, (m, n, k, l) in enumerate(problem_sizes_mnkl):
@ -1847,6 +1981,7 @@ def run_grouped_gemm(
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
# Skip unsupported types
if ab_dtype not in {
@ -1902,66 +2037,22 @@ def run_grouped_gemm(
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run this example!")
# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)
# Create tensors for all groups using the new function
(
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
torch_fp32_tensors_abc,
) = create_tensors_for_all_groups(
problem_sizes_mnkl,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
)
# iterate all groups and create tensors for each group
torch_fp32_tensors_abc = []
torch_tensors_abc = []
cute_tensors_abc = []
strides_abc = []
ptrs_abc = []
for _, (m, n, k, l) in enumerate(problem_sizes_mnkl):
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c])
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
cute_tensors_abc.append(
(
cute_tensor_a,
cute_tensor_b,
cute_tensor_c,
)
)
# Choose A, B, C with the smallest size to create initial tensormaps
key_size_a = lambda item: item[1][0] * item[1][2]
key_size_b = lambda item: item[1][1] * item[1][2]
@ -2078,36 +2169,19 @@ def run_grouped_gemm(
current_stream,
)
# Launch GPU kernel
# Warm up
for _ in range(warmup_iterations):
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
# Execution
for i in range(iterations):
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
torch.cuda.synchronize()
# Compute reference result
if not skip_ref_check:
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
# Compute reference result
for i, (a, b, c) in enumerate(torch_tensors_abc):
ref = torch.einsum(
"mkl,nkl->mnl",
@ -2122,6 +2196,102 @@ def run_grouped_gemm(
rtol=1e-05,
)
def generate_tensors():
# Reuse existing CPU tensors and create new GPU tensors from them
(
ptrs_abc_workspace,
torch_tensors_abc_workspace,
cute_tensors_abc_workspace,
strides_abc_workspace,
_,
) = create_tensors_for_all_groups(
problem_sizes_mnkl,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
torch_fp32_tensors_abc,
)
initial_cute_tensors_abc_workspace = [
cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k)
cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k)
cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n)
]
# Create new tensors for this workspace
tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc_workspace, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc_workspace, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)
tensormap_workspace, _ = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
)
return testing.JitArguments(
initial_cute_tensors_abc_workspace[0],
initial_cute_tensors_abc_workspace[1],
initial_cute_tensors_abc_workspace[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc_workspace,
tensor_of_ptrs_abc_workspace,
tensormap_workspace,
current_stream,
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
sum(
[
sum(
[
torch_tensor.numel() * torch_tensor.element_size()
for torch_tensor in group_tensors
]
)
for group_tensors in torch_tensors_abc
]
)
+
# Add size of strides tensor
tensor_of_strides_abc_torch.numel()
* tensor_of_strides_abc_torch.element_size()
+
# Add size of ptrs tensor
tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size()
+
# Add size of tensormap tensor
tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_grouped_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
@ -2218,6 +2388,12 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -2248,7 +2424,7 @@ if __name__ == "__main__":
torch.manual_seed(2025)
run_grouped_gemm(
run(
args.num_groups,
args.problem_sizes_mnkl,
args.ab_dtype,
@ -2265,5 +2441,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -29,13 +29,14 @@
import argparse
from typing import List, Type, Tuple, Optional
from cuda import cuda
import cuda.bindings.driver as cuda
import torch
import torch.nn.functional as F
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
@ -43,13 +44,16 @@ import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
from .mamba2_ssd_reference import (
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent))
from mamba2_ssd_reference import (
ssd_reference_fp32_all,
ssd_reference_lowprecision_intermediates,
analyze_relative_diffs,
)
from .mamba2_ssd_tile_scheduler import (
from mamba2_ssd_tile_scheduler import (
Mamba2SSDTileSchedulerParams,
Mamba2SSDTileScheduler,
)
@ -122,7 +126,7 @@ class SSDKernel:
*self.epilog_warp_id,
)
)
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
# Named barriers
self.pre_inter_sync_bar_id = 1
@ -1522,7 +1526,10 @@ class SSDKernel:
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b(
local_tidx, smem_bt_internal_, tiled_s2r_b, tBrB_s2r
local_tidx,
smem_bt_internal_,
tiled_s2r_b,
tBrB_s2r,
)
# (MMA, MMA_M, MMA_K, INPUT_STAGE)
@ -3053,7 +3060,7 @@ class SSDKernel:
# SegSum
# fadd2 + fsel + fmul2/mufu + fmul2
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
(
tCompute[subtile_idx],
tCompute[subtile_idx + 1],
@ -3061,11 +3068,11 @@ class SSDKernel:
(tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]),
(-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]),
)
for subtile_idx in range(cute.size(tTR_rQ)):
for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True):
m, n = tCoord[subtile_idx]
if m < n:
tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
# TODO: use math.exp directly
(
tCompute[subtile_idx],
@ -3130,11 +3137,7 @@ class SSDKernel:
dtype,
num_bits_per_copy=128,
)
tiled_r2s_b = cute.make_tiled_copy(
copy_atom_r2s_b,
layout_tv=tiled_s2r_b.layout_tv_tiled,
tiler_mn=tiled_s2r_b.tiler_mn,
)
tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b)
thr_r2s_b = tiled_r2s_b.get_slice(local_tidx)
# Partition shared tensor for smem store Bt
@ -3333,17 +3336,24 @@ class SSDKernel:
)
def run_ssd(
def run(
gbehcdln: Tuple[int, int, int, int, int, int, int, int],
io_dtype: Type[cutlass.Numeric],
cumsum_delta_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
has_d: bool,
d_has_hdim: bool,
fuse_scale_d: str,
tolerance: float,
print_rtol_stats: bool,
ref_lower_precision: bool,
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
has_d = fuse_scale_d != "none"
d_has_hdim = fuse_scale_d == "vector"
print(f"Running B100 Mamba2 SSD with:")
print(f"GBEHCDLN: {gbehcdln}")
print(
@ -3353,6 +3363,10 @@ def run_ssd(
f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}"
)
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
# Unpack parameters
G, B, E, H, C, D, L, N = gbehcdln
@ -3515,39 +3529,146 @@ def run_ssd(
stream,
)
# Launch compiled ssd kernel
compiled_ssd(
x_tensor,
cumsum_delta_tensor,
delta_tensor,
b_tensor,
c_tensor,
y_tensor,
fstate_tensor,
d_tensor,
stream,
# Launch compiled ssd kernel for reference check
if not skip_ref_check:
compiled_ssd(
x_tensor,
cumsum_delta_tensor,
delta_tensor,
b_tensor,
c_tensor,
y_tensor,
fstate_tensor,
d_tensor,
stream,
)
# Reference check
if print_rtol_stats:
print("\nY's Relative diffs:")
analyze_relative_diffs(
y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))
)
print("\nFstate's Relative diffs:")
analyze_relative_diffs(
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
)
torch.testing.assert_close(
y_torch.cpu(),
y_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-02,
)
torch.testing.assert_close(
fstate_torch.cpu(),
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-05,
)
def generate_tensors():
# Reuse existing CPU reference tensors and create new GPU tensors from them
_, x_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, C, L],
[2, 4, 3, 1, 0],
io_dtype,
ref_tensor=x_ref,
dynamic_modes=[2, 3, 4],
)
_, cumsum_delta_tensor_new, _ = create_and_permute_tensor(
[B, EH, C, L],
[3, 2, 1, 0],
cumsum_delta_dtype,
ref_tensor=cumsum_delta_ref,
dynamic_modes=[1, 2, 3],
)
_, delta_tensor_new, _ = create_and_permute_tensor(
[B, EH, C, L],
[3, 2, 1, 0],
io_dtype,
ref_tensor=delta_ref,
dynamic_modes=[1, 2, 3],
)
_, b_tensor_new, _ = create_and_permute_tensor(
[B, G, N, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=b_ref,
dynamic_modes=[2, 3, 4],
)
_, c_tensor_new, _ = create_and_permute_tensor(
[B, G, N, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=c_ref,
dynamic_modes=[2, 3, 4],
)
_, y_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=y_ref,
dynamic_modes=[2, 3, 4],
)
_, fstate_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, N],
[2, 3, 1, 0],
io_dtype,
ref_tensor=fstate_ref,
dynamic_modes=[2, 3],
)
if has_d:
_, d_tensor_new, _ = create_and_permute_tensor(
[EH, D if d_has_hdim else 1],
[1, 0],
io_dtype,
ref_tensor=d_ref,
dynamic_modes=[1],
)
else:
d_tensor_new = d_tensor
return testing.JitArguments(
x_tensor_new,
cumsum_delta_tensor_new,
delta_tensor_new,
b_tensor_new,
c_tensor_new,
y_tensor_new,
fstate_tensor_new,
d_tensor_new,
stream,
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
x_torch.numel() * x_torch.element_size()
+ cumsum_delta_torch.numel() * cumsum_delta_torch.element_size()
+ delta_torch.numel() * delta_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
+ y_torch.numel() * y_torch.element_size()
+ fstate_torch.numel() * fstate_torch.element_size()
)
if has_d:
one_workspace_bytes += d_torch.numel() * d_torch.element_size()
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_ssd,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
# Reference check
if print_rtol_stats:
print("\nY's Relative diffs:")
analyze_relative_diffs(y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)))
print("\nFstate's Relative diffs:")
analyze_relative_diffs(
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
)
torch.testing.assert_close(
y_torch.cpu(),
y_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-02,
)
torch.testing.assert_close(
fstate_torch.cpu(),
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-05,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
@ -3586,15 +3707,53 @@ if __name__ == "__main__":
)
parser.add_argument(
"--ref_lower_precision",
type=bool,
action="store_true",
default=True,
help="Use lower precision for reference check",
)
parser.add_argument(
"--no-ref_lower_precision",
action="store_false",
dest="ref_lower_precision",
default=False,
help="Disable lower precision for reference check",
)
parser.add_argument(
"--tolerance", type=float, default=5e-02, help="Tolerance for validation"
)
parser.add_argument(
"--print_rtol_stats", type=bool, default=True, help="Print rtol stats"
"--print_rtol_stats",
action="store_true",
default=True,
help="Enable print rtol stats",
)
parser.add_argument(
"--no-print_rtol_stats",
action="store_false",
dest="print_rtol_stats",
default=False,
help="Disable print rtol stats",
)
parser.add_argument(
"--warmup_iterations",
type=int,
default=0,
help="Number of warmup iterations",
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -3602,18 +3761,18 @@ if __name__ == "__main__":
if len(args.gbehcdln) != 8:
parser.error("--gbehcdln must contain exactly 8 values")
has_d = args.fuse_scale_d != "none"
d_has_hdim = args.fuse_scale_d == "vector"
run_ssd(
run(
args.gbehcdln,
args.io_dtype,
args.cumsum_delta_dtype,
args.acc_dtype,
has_d,
d_has_hdim,
args.fuse_scale_d,
args.tolerance,
args.print_rtol_stats,
args.ref_lower_precision,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -35,6 +35,7 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
@ -166,6 +167,24 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -264,7 +283,7 @@ class HopperWgmmaGemmKernel:
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
self.num_threads_per_warp_group = 128
self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
self.ab_stage = None
self.epi_stage = None
@ -1309,7 +1328,7 @@ class HopperWgmmaGemmKernel:
}:
is_valid = False
# tested acc_dtype
if acc_dtype != cutlass.Float32:
if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
is_valid = False
# tested c_dtype
if c_dtype not in {
@ -1335,7 +1354,7 @@ class HopperWgmmaGemmKernel:
return is_valid
def run_dense_gemm(
def run(
mnkl: Tuple[int, int, int, int],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
@ -1347,9 +1366,43 @@ def run_dense_gemm(
tile_shape_mnk: Tuple[int, int, int],
cluster_shape_mn: Tuple[int, int],
tolerance: float,
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""
Prepare the A/B/C tensors, launch the GPU kernel, and optionally run the reference check.
:param mnkl: Problem size (M, N, K, L)
:type mnkl: Tuple[int, int, int, int]
:param a_dtype: Data type for input tensor A
:type a_dtype: Type[cutlass.Numeric]
:param b_dtype: Data type for input tensor B
:type b_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Data type for accumulation during matrix multiplication
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param tile_shape_mnk: CTA tile shape (M, N, K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mn: Cluster shape (M, N)
:type cluster_shape_mn: Tuple[int, int]
:param tolerance: Tolerance value for reference validation comparison
:type tolerance: float
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Hopper Dense GEMM with:")
@ -1360,6 +1413,10 @@ def run_dense_gemm(
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
# Unpack parameters
m, n, k, l = mnkl
@ -1437,46 +1494,76 @@ def run_dense_gemm(
stream = cuda.CUstream(torch_stream.cuda_stream)
# compile gemm kernel
compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
# execution
compiled_gemm(mA, mB, mC, stream)
torch.cuda.synchronize()
if not skip_ref_check:
# execution
compiled_gemm(mA, mB, mC, stream)
# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
torch.cuda.synchronize()
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# k major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# n major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
def generate_tensors():
_, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
_, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
_, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch.numel() * a_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
exec_time = testing.benchmark(
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
args = parse_arguments()
run_dense_gemm(
run(
args.mnkl,
args.a_dtype,
args.b_dtype,
@ -1488,5 +1575,9 @@ if __name__ == "__main__":
args.tile_shape_mnk,
args.cluster_shape_mn,
args.tolerance,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -399,6 +399,70 @@
"\n",
"tensor_print_example3()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def print_tensor_gpu(src: cute.Tensor):\n",
" print(src)\n",
" cute.print_tensor(src)\n",
"\n",
"@cute.jit\n",
"def print_tensor_host(src: cute.Tensor):\n",
" print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
" [[-0.690547, -0.274619, -1.659539, ],\n",
" [-1.843524, -1.648711, 1.163431, ],\n",
" [-0.716668, -1.900705, 0.592515, ],\n",
" [ 0.711333, -0.552422, 0.860237, ]])\n"
]
}
],
"source": [
"import torch\n",
"def tensor_print_example4():\n",
" a = torch.randn(4, 3, device=\"cuda\")\n",
" cutlass.cuda.initialize_cuda_context()\n",
" print_tensor_host(from_dlpack(a))\n",
"\n",
"tensor_print_example4()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future."
]
}
],
"metadata": {

View File

@ -256,16 +256,6 @@
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
"\n",
"@cute.kernel\n",
"def print_tensor_gpu(ptr: cute.Pointer):\n",
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
" tensor = cute.make_tensor(ptr, layout)\n",
"\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
"\n",
" if tidx == 0:\n",
" cute.print_tensor(tensor)\n",
"\n",
"\n",
"# Create a tensor with sequential data using torch\n",
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",

View File

@ -363,7 +363,7 @@
"| | \"few_channels\" | optimized for small `C` and requires `C % alignment_input == 0`|\n",
"| | \"fixed_channels\" | optimized for small `C` and requires `C == alignment_input` |\n",
"|Dgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimzed for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"| | \"optimized\" | Optimized for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"|Wgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimized for and require `K % alignment_grad_output == 0`, and `C % alignment_input == 0`|\n",
"\n",