v4.1 release update v2. (#2481)
@ -64,7 +64,7 @@ ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutla
|
||||
ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not
|
||||
enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do
|
||||
that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB
|
||||
to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C
|
||||
to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C
|
||||
which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the
|
||||
data type of output ElementOutput (int32_t), the number of elements per vector memory access (16),
|
||||
data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X +
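In code, the setup this passage describes looks roughly like the following sketch (the element types, the access width, and the `LinearCombination` epilogue are stand-ins matching the description above, not the example's exact source):

```cpp
#include <cstdint>
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"

using ElementOutput          = int32_t;  // data type of the output matrix
using ElementAccumulator     = int32_t;  // data type of the accumulator
using ElementComputeEpilogue = int32_t;  // data type used to compute alpha * X + beta * C

using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Epilogue: D = alpha * accumulator + beta * C, vectorized over kAccessWidth elements.
constexpr int kAccessWidth = 16;  // the "elements per vector memory access" quoted above
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, kAccessWidth, ElementAccumulator, ElementComputeEpilogue>;
```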
@ -64,7 +64,7 @@ ElementComputeEpilogue (int32_t), ElementInputA (int8_t), ElementInputB (int8_t)
|
||||
(int32_t). Communicating just the data type is not enough. As the data is laid out linearly in
|
||||
memory, we have to convey the layout of matrices. We do that by initializing template variable
|
||||
LayoutInputA to column major cutlass variable, LayoutInputB to row major and LayoutOutput to row
|
||||
major. Next, we setup rules to comptue alpha * X + beta * C which is called epilogue of the kernel.
|
||||
major. Next, we setup rules to compute alpha * X + beta * C which is called epilogue of the kernel.
|
||||
We initialize template variable EpilogueOp, which takes the data type of output ElementOutput
|
||||
(int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t)
|
||||
and data type of computation of linear combination (alpha * X + beta * C).
|
||||
|
||||
@ -66,7 +66,7 @@ ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB
|
||||
ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out
|
||||
linearly in memory, we have to convey the layout of tensors. We do that by initializing template
|
||||
variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup
|
||||
rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template
|
||||
rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template
|
||||
variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of
|
||||
elements per vector memory access (32), data type of accumulator (int32_t) and data type of
|
||||
computation of linear combination (alpha * X + beta * C).
|
||||
|
||||
@ -177,7 +177,7 @@ public:
|
||||
if(args.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
|
||||
// The user needs to call a reduction operator to optain the final output tensor
|
||||
// The user needs to call a reduction operator to obtain the final output tensor
|
||||
workspace_bytes =
|
||||
sizeof(ElementAccumulator) *
|
||||
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *
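The factor after the trailing `*` is cut off by the diff context and is left as-is; as a rough standalone illustration of how such split-K scratch space scales (each k-slice stores its own partial output tile in accumulator precision):

```cpp
#include <cstddef>

// Hedged sketch, not the example's code: scratch bytes for split-K parallel partials.
size_t split_k_partials_bytes(size_t output_elements, int split_k_slices) {
  return sizeof(float) /* ElementAccumulator assumed to be float here */
         * output_elements * static_cast<size_t>(split_k_slices);
}
// e.g. a 1024 x 1024 output with 4 slices: 4 B * 1048576 * 4 = 16 MiB of scratch.
```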
@ -153,7 +153,7 @@ struct Options {
|
||||
|
||||
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
|
||||
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
|
||||
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
|
||||
<< " run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n"
|
||||
<< " back-to-back GEMMs are subject to.s"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
|
||||
@ -248,7 +248,7 @@ struct B2bGemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// Epilogue params remain constant across all problems in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params epilogue0;
|
||||
typename OutputOp1::Params epilogue1;
|
||||
@ -402,7 +402,7 @@ struct B2bGemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// Epilogue params remain constant across all problems in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
@ -434,7 +434,7 @@ struct B2bGemm {
|
||||
// Only row-major outputs are currently supported, so no transpose is performed
|
||||
}
|
||||
|
||||
/// Returns non-grouped paramaters to be used as input to the kernel-level
|
||||
/// Returns non-grouped parameters to be used as input to the kernel-level
|
||||
/// operator for the problem indicated by problem_visitor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params to_single_params(const ProblemVisitor& problem_visitor) const {
|
||||
|
||||
@ -560,7 +560,7 @@ struct DefaultB2bConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
// multistage pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
|
||||
@ -606,7 +606,7 @@ struct DefaultB2bConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
// multistage pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
|
||||
@ -277,7 +277,7 @@ public:
|
||||
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
|
||||
@ -298,7 +298,7 @@ public:
|
||||
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
|
||||
@ -93,7 +93,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
|
||||
@ -203,7 +203,7 @@ requires any memory for scratch space.
|
||||
If yes, we reserve scratch space and pass it along
|
||||
with other arguments to initialize the CUTLASS kernel.
|
||||
|
||||
After lauching the CUTLASS kernel, this example runs
|
||||
After launching the CUTLASS kernel, this example runs
|
||||
a reference convolution kernel (from CUTLASS utilities)
|
||||
to check correctness.
|
||||
*/
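The flow the comment outlines typically looks like the sketch below (a generic CUTLASS 2.x host-side pattern rather than this example's exact code; `Gemm`, `arguments`, and the examples' `CUTLASS_CHECK` helper macro are assumed to be defined elsewhere):

```cpp
Gemm gemm_op;                                                    // instantiated CUTLASS kernel

// Query and allocate scratch space, if the kernel needs any.
size_t workspace_bytes = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);

CUTLASS_CHECK(gemm_op.can_implement(arguments));                 // is the problem supported?
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));   // pass workspace along
CUTLASS_CHECK(gemm_op());                                        // launch the CUTLASS kernel
// ...then run the reference kernel (from CUTLASS utilities) and compare outputs.
```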
@ -144,7 +144,7 @@ int run() {
|
||||
// Construct Gemm ProblemSize with user defined output size
|
||||
cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024};
|
||||
|
||||
// Stride factor shows the distance between two elements in the differnet dimensions. The
|
||||
// Stride factor shows the distance between two elements in the different dimensions. The
|
||||
// first data is the logical distance between two rows, the second is between two columns.
|
||||
// CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory<Layout>::layout_factory
|
||||
// to help to convert stride_factor to the two strides.
|
||||
|
||||
@ -55,7 +55,7 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Define the overal warp-level problem shape
|
||||
// Define the overall warp-level problem shape
|
||||
int const kM = 27;
|
||||
int const kN = 31;
|
||||
int const kK = 17;
|
||||
|
||||
@ -59,7 +59,7 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Define the overal warp-level problem shape
|
||||
// Define the overall warp-level problem shape
|
||||
int const kM = 14;
|
||||
int const kN = 27;
|
||||
int const kK = 17;
|
||||
|
||||
@ -30,7 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
// This example fuses gather before GEMM and scatter after GEMM into the same
|
||||
// GEMM kernel. Gather and scatter operation is controled by an index vector
|
||||
// GEMM kernel. Gather and scatter operation is controlled by an index vector
|
||||
// to select rows or columns from A, B, C or D matrices.
|
||||
//
|
||||
// Suppose, all matrices are column major. The pseudo code of the fused kernel
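(The example's own pseudocode continues beyond this diff excerpt.) As an illustration only, the gather -> GEMM -> scatter idea can be sketched in plain C++ for column-major matrices, with hypothetical index vectors:

```cpp
// Illustrative sketch, not the example's fused kernel: gather rows of A (and C)
// through gather_idx, run the GEMM on the gathered rows, scatter rows of D
// through scatter_idx. All matrices are column-major.
void gather_gemm_scatter(int N, int K,
                         const float* A, int lda, const int* gather_idx,
                         const float* B, int ldb,
                         const float* C, int ldc,
                         float* D, int ldd, const int* scatter_idx,
                         int num_indices, float alpha, float beta) {
  for (int i = 0; i < num_indices; ++i) {
    int a_row = gather_idx[i];    // gather: which row of A (and C) to read
    int d_row = scatter_idx[i];   // scatter: which row of D to write
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[a_row + k * lda] * B[k + n * ldb];   // column-major addressing
      }
      D[d_row + n * ldd] = alpha * acc + beta * C[a_row + n * ldc];
    }
  }
}
```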
@ -87,7 +87,7 @@ public:
|
||||
using ElementLayernormCompute = ElementLayernormCompute_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
|
||||
// Pre-processing has ensured the layout equivelent to RowMajor
|
||||
// Pre-processing has ensured the layout equivalent to RowMajor
|
||||
using Layout = cutlass::layout::RowMajor;
|
||||
|
||||
using TensorVariance = TensorRef<ElementVariance, Layout>;
|
||||
|
||||
@ -87,7 +87,7 @@ parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices
|
||||
"TensorNHWC", "TensorNC32HW32"],
|
||||
help="Memory layout of input tensor A")
|
||||
parser.add_argument('-aa', '--alignment_a', default=1,
|
||||
type=int, help="Memory alignement of input tensor A")
|
||||
type=int, help="Memory alignment of input tensor A")
|
||||
# B
|
||||
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
|
||||
"TensorNHWC", "TensorC32RSK32"],
|
||||
|
||||
@ -86,7 +86,7 @@ parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
|
||||
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
|
||||
help="Memory layout of input tensor A")
|
||||
parser.add_argument('-aa', '--alignment_a', default=1,
|
||||
type=int, help="Memory alignement of input tensor A")
|
||||
type=int, help="Memory alignment of input tensor A")
|
||||
# B
|
||||
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
|
||||
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
|
||||
|
||||
@ -55,7 +55,7 @@
|
||||
```
|
||||
|
||||
In practice, and for numerical stability reasons,
|
||||
we also substract the maximum so far (`mi`) before doing
|
||||
we also subtract the maximum so far (`mi`) before doing
|
||||
the exponential. When we encounter new keys, the maximum
|
||||
used to compute O so far (`m_prime`) can differ from the
|
||||
current maximum, so we update O before accumulating with
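In code, the correction step described here can be sketched as follows (a standalone illustration reusing the comment's `mi`/`m_prime` naming, not the kernel's actual implementation):

```cpp
#include <algorithm>
#include <cmath>

// When the running maximum changes, rescale the running softmax sum and the
// partial output accumulated so far before adding contributions from new keys.
void rescale_partial_output(float& m_prime, float& s_prime, float* O, int head_dim,
                            float new_row_max) {
  float mi = std::max(m_prime, new_row_max);   // updated running maximum
  float correction = std::exp(m_prime - mi);   // <= 1, accounts for the old maximum
  s_prime *= correction;                       // running softmax denominator
  for (int d = 0; d < head_dim; ++d) {
    O[d] *= correction;                        // partial output accumulated so far
  }
  m_prime = mi;
  // ...then accumulate exp(score - mi) * V into O and exp(score - mi) into s_prime.
}
```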
@ -55,7 +55,7 @@
|
||||
```
|
||||
|
||||
In practice, and for numerical stability reasons,
|
||||
we also substract the maximum so far (`mi`) before doing
|
||||
we also subtract the maximum so far (`mi`) before doing
|
||||
the exponential. When we encounter new keys, the maximum
|
||||
used to compute O so far (`m_prime`) can differ from the
|
||||
current maximum, so we update O before accumulating with
|
||||
|
||||
@ -31,7 +31,7 @@
|
||||
|
||||
/*! \file
|
||||
\brief Cutlass provides helper template functions to figure out the right
|
||||
datastructures to instanciate to run a GEMM with various parameters (see
|
||||
datastructures to instantiate to run a GEMM with various parameters (see
|
||||
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
|
||||
instantiation priority rules, it will only create an MmaMultiStage with
|
||||
kStages=3 (otherwise creates an MmePipelined - which is not compatible with
|
||||
@ -83,7 +83,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
typename Enable_ = void>
|
||||
struct FindDefaultMma {
|
||||
|
||||
@ -522,7 +522,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
|
||||
// For API compatibility with MmaMultistageFromSharedMemory
|
||||
// but not supported as it worsens perf: older gpus < sm80 don't
|
||||
// support async tranfers and have to waste registers
|
||||
// support async transfers and have to waste registers
|
||||
CUTLASS_DEVICE
|
||||
void set_prologue_done(bool value) {}
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Instanciates the right WarpIterator to read from shared memory
|
||||
\brief Instantiates the right WarpIterator to read from shared memory
|
||||
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
|
||||
data dumped with `B2bGemm::accumToSmem`.
|
||||
*/
|
||||
|
||||
@ -86,7 +86,7 @@ namespace threadblock {
|
||||
/// To be efficient, this assumes the iterator will be dereferenced and advanced
|
||||
/// at least once outside any looping structure to minimize integer arithmetic.
|
||||
///
|
||||
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to
|
||||
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
|
||||
/// dereferencing the iterator.
|
||||
///
|
||||
///
|
||||
|
||||
@ -49,7 +49,7 @@
|
||||
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
|
||||
for this example:
|
||||
a_rows - Rows in the sparse matrix.
|
||||
a_cols - Colums in the sparse matrix.
|
||||
a_cols - Columns in the sparse matrix.
|
||||
a_ell_blocksize - Size of the ELL-Blocks.
|
||||
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
|
||||
tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns)
|
||||
|
||||
@ -153,7 +153,7 @@ class gen_device:
|
||||
|
||||
warp_M_tile = 32
|
||||
|
||||
# Determine maxmimum N_tile
|
||||
# Determine maximum N_tile
|
||||
Max_Ntile = 0
|
||||
for layer in self.fuse_gemm_info:
|
||||
n_tile = layer['mnk'][1]
|
||||
|
||||
@ -76,9 +76,9 @@ class gen_verify:
|
||||
)
|
||||
|
||||
|
||||
def get_params(self, declartion = True):
|
||||
def get_params(self, declaration = True):
|
||||
code = ""
|
||||
if declartion:
|
||||
if declaration:
|
||||
for param in self.params:
|
||||
code += param[0] + " " + param[1] + ";\n"
|
||||
|
||||
|
||||
@ -64,8 +64,8 @@ def write_2_headfile(filename, file_dir, string):
|
||||
with open(file_dir + filename, 'w') as f:
|
||||
f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string)
|
||||
|
||||
def var_idx(varaiable, index):
|
||||
return varaiable + str(index)
|
||||
def var_idx(variable, index):
|
||||
return variable + str(index)
|
||||
|
||||
|
||||
def list_2_string(input_list, ):
|
||||
|
||||
@ -78,7 +78,7 @@
|
||||
a single default value.
|
||||
|
||||
CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of
|
||||
the collective is specified via the schedule tags that corresond to the underlying collective's
|
||||
the collective is specified via the schedule tags that correspond to the underlying collective's
|
||||
dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto`
|
||||
are special cases of these schedules that allow the builder to also decide the dispatch policy for you,
|
||||
therefore letting the builder pick the collective specialization.
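For illustration, a CUTLASS 3.x builder instantiation using these auto schedules looks roughly like the sketch below (SM90 is assumed, and `TileShape`, `ClusterShape`, and the element/layout/alignment aliases are placeholders, not definitions from this diff):

```cpp
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementCompute,
    ElementC, LayoutC, AlignmentC,
    ElementD, LayoutD, AlignmentD,
    cutlass::epilogue::collective::EpilogueScheduleAuto   // builder picks the epilogue policy
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::collective::KernelScheduleAuto         // builder picks the mainloop policy
  >::CollectiveOp;
```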
@ -425,7 +425,7 @@ int main(int argc, char const **args) {
|
||||
// Pipeline Depth to be used i.e number of A, B buffers in shared memory
|
||||
constexpr int PipelineStages = 8;
|
||||
|
||||
// Let's choose a Warp-Specialized Mainloop implemention which uses TMA
|
||||
// Let's choose a Warp-Specialized Mainloop implementation which uses TMA
|
||||
// Note : This requires / assumes the tensors to be 16B aligned
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>;
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
\brief Example of a Hopper gather+GEMM+scatter kernel fusion.
|
||||
|
||||
This example fuses gather before GEMM and scatter after GEMM into the same
|
||||
GEMM kernel. Gather and scatter operation is controled by an index vector
|
||||
GEMM kernel. Gather and scatter operation is controlled by an index vector
|
||||
to select rows or columns from A, B, C or D matrices.
|
||||
|
||||
Gather/scatter operations are always performed along a strided dimension
|
||||
|
||||
@ -65,7 +65,7 @@
|
||||
The approach relies on two things:
|
||||
- The ability of CUTLASS 3 to naturally perform general tensor contractions (GETT) owing to the
|
||||
flexibility of CuTe's hierarchical layouts (see example 51_hopper_gett for more details).
|
||||
- The harware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
|
||||
- The hardware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
|
||||
(almost) arbitrary strides, which can be used to represent a permuted view of the data.
|
||||
|
||||
In this example we reuse the permutation classes of examples 39_gemm_permute as operation tags.
|
||||
|
||||
@ -188,7 +188,7 @@ Running this example on an RTX 3080Ti prints the following performance numbers (
|
||||
|
||||
```
|
||||
$> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check
|
||||
Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.
|
||||
Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.
|
||||
|
||||
Allocating tensors ... done.
|
||||
Initializing data ... done.
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propogation kernel
|
||||
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel
|
||||
capable of operating on both affine and gather/scatter tensors.
|
||||
|
||||
This example demonstartes a few super cool features of CUTLASS and CuTe. It shows off
|
||||
@ -284,7 +284,7 @@ int ampere_gather_scatter_conv_fprop(
|
||||
int
|
||||
main(int argc, char const** argv) {
|
||||
cutlass::CommandLine cmd(argc, argv);
|
||||
std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n";
|
||||
std::cout << "Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.\n\n";
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
std::cout
|
||||
<< "Options:\n"
|
||||
|
||||
@ -291,7 +291,7 @@ struct Options {
|
||||
// Post-process the problem sizes
|
||||
bin_problems();
|
||||
|
||||
// Initalize alpha array
|
||||
// Initialize alpha array
|
||||
randomize_alpha_ptr_array(cmd);
|
||||
}
|
||||
|
||||
|
||||
@ -358,7 +358,7 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
// Layout SFA and SFB represent logically broadcasting data in CuTe.
|
||||
// E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGraunularityK, K / ScaleGranularityK))
|
||||
// and strides ((0, 1), (0, M / ScaleGraunuarlityM)), then each collection of ScaleGranularityM x ScaleGranularityK
|
||||
// indecies in the tensor map to the same offset.
|
||||
// indices in the tensor map to the same offset.
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
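To make the broadcast concrete, here is a standalone CuTe sketch (illustrative only, not the `ScaleConfig` layouts used by this example) in which stride-0 modes map every position inside a tile to the same offset:

```cpp
#include <cstdio>
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  // Shape ((4, 2), (8, 3)) with strides ((0, 1), (0, 2)): the stride-0 modes mean
  // every position inside a 4 x 8 tile hits the same offset, i.e. one scale value
  // is logically broadcast across the whole tile.
  auto layout_SF = make_layout(
      make_shape (make_shape (Int<4>{}, 2), make_shape (Int<8>{}, 3)),
      make_stride(make_stride(Int<0>{}, Int<1>{}), make_stride(Int<0>{}, Int<2>{})));

  int off_a = layout_SF(make_coord(make_coord(0, 1), make_coord(0, 2)));
  int off_b = layout_SF(make_coord(make_coord(3, 1), make_coord(7, 2)));
  std::printf("%d %d\n", off_a, off_b);  // both print 5: same tile -> same offset
}
```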
@ -61,7 +61,7 @@
|
||||
# Heuristic mode with deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic
|
||||
|
||||
# Stream-K mode with determinsitic reduction
|
||||
# Stream-K mode with deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic
|
||||
|
||||
# Split-K mode with a splitting factor of 2 and deterministic reduction
|
||||
|
||||
@ -850,7 +850,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::cout << " Verfication is turned off for this run." << std::endl;
|
||||
std::cout << " Verification is turned off for this run." << std::endl;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
|
||||
@ -36,7 +36,7 @@
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of dgrad convolution kernel is, take 3D convolution as an example:
|
||||
Xformed Actication (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
|
||||
Xformed Activation (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation
|
||||
|
||||
@ -36,7 +36,7 @@
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of fprop convolution kernel is, take 3D convolution as an example:
|
||||
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK)
|
||||
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Activation (NZPQK)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation
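For reference, the implicit-GEMM extents implied by that correspondence can be sketched as follows (the standard fprop mapping for NDHWC activations and KTRSC filters; dgrad and wgrad permute the roles analogously, and this is not code from the diff):

```cpp
// Hedged sketch: GEMM extents for 3D fprop with Activation NDHWC, Filter KTRSC,
// and Xformed Activation NZPQK.
struct GemmExtents { int m, n, k; };

GemmExtents fprop_as_gemm(int N, int Z, int P, int Q,   // output tensor extents
                          int K,                        // number of filters
                          int T, int R, int S, int C) { // filter footprint and channels
  return {
    N * Z * P * Q,   // GEMM M: one row per output pixel
    K,               // GEMM N: one column per filter
    T * R * S * C    // GEMM K: reduction over the filter window and input channels
  };
}
```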
@ -36,7 +36,7 @@
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example:
|
||||
Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
|
||||
Xformed Activation (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter
|
||||
|
||||
@ -505,8 +505,12 @@ struct FwdRunner {
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB);
|
||||
|
||||
fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
\brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3.
|
||||
|
||||
This example showcases the use of CUTLASS to build backward fused
|
||||
multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting
|
||||
multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting
|
||||
the NVIDIA Blackwell architecture.
|
||||
|
||||
Background and motivation
|
||||
@ -117,6 +117,7 @@ struct Options {
|
||||
std::vector<int> varlen_q;
|
||||
std::vector<int> varlen_k;
|
||||
int d = 128;
|
||||
int d_vo = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
@ -178,6 +179,7 @@ struct Options {
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("d_vo", d_vo, d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
@ -301,6 +303,7 @@ struct Options {
|
||||
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
|
||||
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
|
||||
<< " --d=<int> Sets the D extent\n"
|
||||
<< " --d_vo=<int> Sets the D_VO extent\n"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
@ -387,6 +390,7 @@ struct ExampleResult {
|
||||
|
||||
template<
|
||||
bool kIsVarlen,
|
||||
bool kIsMla,
|
||||
class TileShape,
|
||||
class DispatchPolicy,
|
||||
class ActiveMask,
|
||||
@ -404,8 +408,8 @@ struct BwdRunner {
|
||||
// Q K D (H B)
|
||||
using ProblemShape = std::conditional_t<
|
||||
kIsVarlen,
|
||||
cute::tuple<VariableLength, VariableLength, int, cute::tuple<int, int>>,
|
||||
cute::tuple<int, int, int, cute::tuple<int, int>>
|
||||
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
|
||||
cute::tuple<int, int, int, int, cute::tuple<int, int>>
|
||||
>;
|
||||
|
||||
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
|
||||
@ -461,45 +465,45 @@ struct BwdRunner {
|
||||
// Methods
|
||||
//
|
||||
bool verify(const ProblemShape& problem_shape) {
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
auto [H, B] = HB;
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,2,4>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,2,4>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,3,4>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,3,4>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
select<0,4>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,2,4>(problem_shape),
|
||||
stride_dQ);
|
||||
|
||||
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,2,4>(problem_shape),
|
||||
stride_dK);
|
||||
|
||||
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,3,4>(problem_shape),
|
||||
stride_dV);
|
||||
|
||||
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,3,4>(problem_shape),
|
||||
stride_dO);
|
||||
|
||||
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
|
||||
@ -595,14 +599,14 @@ struct BwdRunner {
|
||||
ProblemShape problem_shape{
|
||||
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
|
||||
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
|
||||
options.d, {options.h, options.b}
|
||||
options.d, options.d_vo, {options.h, options.b}
|
||||
};
|
||||
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, make_shape(options.h, 1));
|
||||
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1));
|
||||
|
||||
return cute::make_tuple(problem_shape, tensor_shape);
|
||||
}
|
||||
else {
|
||||
ProblemShape problem_shape{options.q, options.k, options.d, {options.h, options.b}};
|
||||
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}};
|
||||
return cute::make_tuple(problem_shape, problem_shape);
|
||||
}
|
||||
}
|
||||
@ -610,24 +614,25 @@ struct BwdRunner {
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
ProblemShape initialize(Options const& options) {
|
||||
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
|
||||
auto [Q, K, D, HB] = tensor_shape;
|
||||
auto [Q, K, D, D_VO, HB] = tensor_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
|
||||
// for varlen, Q == total_Q, K == total_K, B = 1
|
||||
// but in problem_shape, they've got to be max_Q/max_K, and B = B
|
||||
|
||||
auto shape_QO = make_shape(Q, D, make_shape(H, B));
|
||||
auto shape_KV = make_shape(K, D, make_shape(H, B));
|
||||
auto shape_Q = make_shape(Q, D, make_shape(H, B));
|
||||
auto shape_O = make_shape(Q, D_VO, make_shape(H, B));
|
||||
auto shape_K = make_shape(K, D, make_shape(H, B));
|
||||
auto shape_V = make_shape(K, D_VO, make_shape(H, B));
|
||||
auto shape_LSE = make_shape(Q, make_shape(H, B));
|
||||
|
||||
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
|
||||
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
|
||||
stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H));
|
||||
stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H));
|
||||
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
|
||||
|
||||
stride_V = stride_K;
|
||||
stride_O = stride_Q;
|
||||
|
||||
stride_dQ = stride_Q;
|
||||
stride_dK = stride_K;
|
||||
stride_dV = stride_V;
|
||||
@ -637,20 +642,20 @@ struct BwdRunner {
|
||||
return size(make_shape(1ull, shape));
|
||||
};
|
||||
|
||||
block_Q.reset(lsize(shape_QO));
|
||||
block_K.reset(lsize(shape_KV));
|
||||
block_V.reset(lsize(shape_KV));
|
||||
block_O.reset(lsize(shape_QO));
|
||||
block_Q.reset(lsize(shape_Q));
|
||||
block_K.reset(lsize(shape_K));
|
||||
block_V.reset(lsize(shape_V));
|
||||
block_O.reset(lsize(shape_O));
|
||||
block_LSE.reset(lsize(shape_LSE));
|
||||
|
||||
block_dQ.reset(lsize(shape_QO));
|
||||
block_dK.reset(lsize(shape_KV));
|
||||
block_dV.reset(lsize(shape_KV));
|
||||
block_dO.reset(lsize(shape_QO));
|
||||
block_dQ.reset(lsize(shape_Q));
|
||||
block_dK.reset(lsize(shape_K));
|
||||
block_dV.reset(lsize(shape_V));
|
||||
block_dO.reset(lsize(shape_O));
|
||||
|
||||
block_ref_dQ.reset(lsize(shape_QO));
|
||||
block_ref_dK.reset(lsize(shape_KV));
|
||||
block_ref_dV.reset(lsize(shape_KV));
|
||||
block_ref_dQ.reset(lsize(shape_Q));
|
||||
block_ref_dK.reset(lsize(shape_K));
|
||||
block_ref_dV.reset(lsize(shape_V));
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
@ -665,23 +670,23 @@ struct BwdRunner {
|
||||
initialize_block(block_ref_dV, seed + 2035);
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,2,4>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,2,4>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,3,4>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,3,4>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
select<0,4>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
if (! options.skip_reference) {
|
||||
@ -698,7 +703,7 @@ struct BwdRunner {
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, ActiveMask>;
|
||||
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, kIsMla, ActiveMask>;
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
@ -811,12 +816,12 @@ struct BwdRunner {
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
|
||||
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(get<0>(problem_shape));
|
||||
flops *= static_cast<double>(get<1>(problem_shape));
|
||||
flops *= static_cast<double>(get<2>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,1>(problem_shape));
|
||||
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
|
||||
flops *= static_cast<double>(get<4,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<4,1>(problem_shape));
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.runtime_ms = runtime_ms;
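A plausible reading of the updated FLOP count (inferred from the formula above rather than stated in the diff): the backward pass performs three GEMM-shaped products whose inner dimension is D (S = Q K^T, dQ = dS K, dK = dS^T Q) and two whose inner dimension is D_VO (dP = dO V^T, dV = P^T dO), so FLOPs ~= 2 * Q * K * H * B * (3*D + 2*D_VO), halved under the causal mask because only about half of the (q, k) tiles are computed.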
@ -892,7 +897,7 @@ template<class Mask>
|
||||
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
||||
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
BwdRunner<decltype(is_varlen)::value, false,decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
});
|
||||
@ -900,7 +905,7 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
|
||||
|
||||
using HeadDim = _64;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -909,7 +914,7 @@ template<class Mask>
|
||||
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
||||
BwdRunner<decltype(is_varlen)::value, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
BwdRunner<decltype(is_varlen)::value, false, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
});
|
||||
@ -917,7 +922,22 @@ void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
|
||||
|
||||
using HeadDim = _128;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
template<class Mask>
|
||||
void run_bwd_mla_192(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
||||
BwdRunner<decltype(is_varlen)::value, true, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
});
|
||||
};
|
||||
|
||||
using HeadDim = _192;
|
||||
|
||||
run(Shape<_64, _128, HeadDim, _128>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -981,7 +1001,7 @@ int main_single(int argc, char const **args) {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
|
||||
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
@ -998,12 +1018,15 @@ int main_single(int argc, char const **args) {
|
||||
};
|
||||
|
||||
with_causal([&](auto fusion) {
|
||||
if (options.d <= 64) {
|
||||
if (options.d <= 64 && options.d_vo == options.d) {
|
||||
run_bwd_64(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 128) {
|
||||
else if (options.d <= 128 && options.d_vo == options.d) {
|
||||
run_bwd_128(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d == 192 && options.d_vo == 128) {
|
||||
run_bwd_mla_192(fusion, options, hw_info);
|
||||
}
|
||||
else {
|
||||
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
|
||||
}
|
||||
|
||||
@ -485,7 +485,11 @@ struct MlaFwdRunner {
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
|
||||
auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB);
|
||||
|
||||
fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
@ -84,6 +84,8 @@ set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
|
||||
set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
|
||||
set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
|
||||
@ -174,6 +176,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
TEST_VARLEN_12
|
||||
TEST_VARLEN_13
|
||||
TEST_VARLEN_14
|
||||
TEST_BWD_MLA_BASIC
|
||||
TEST_BWD_MLA_VARLEN
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
@ -37,13 +37,19 @@ There are three kernels to compute backwards:
|
||||
|
||||
`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.
|
||||
|
||||
## MLA Blackwell Backward
|
||||
|
||||
The sample also provides the feature of MLA backward (d=192, d_vo=128). To enable MLA backward, please specify `--d=192 --d_vo=128` when running the bwd sample.
|
||||
|
||||
`Sm100FmhaBwdMlaKernelTmaWarpSpecialized` is the main point for MLA backward. The MLA approach is slightly different from the original one to enable high performance with the MLA shape.
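For example, mirroring the test configuration added in this commit, an invocation could look like `./77_blackwell_fmha_bwd_fp16 --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no` (the `_fp16` suffix is an assumed precision variant of the `77_blackwell_fmha_bwd_${PREC}` target).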
# MLA Inference for Blackwell
|
||||
|
||||
This sample provides code for fused multi-head latent attention inference in
|
||||
the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
|
||||
It supports fp16, bf16, and fp8 input and output types.
|
||||
|
||||
To accomodate the large output accumulator due to the large latent head dimension,
|
||||
To accommodate the large output accumulator due to the large latent head dimension,
|
||||
the sample demonstrates how to leverage 2Sm Blackwell tensor cores.
|
||||
|
||||
Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
|
||||
|
||||
@ -39,6 +39,7 @@
|
||||
|
||||
#include "../device/fmha.hpp"
|
||||
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
|
||||
#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_convert.hpp"
|
||||
|
||||
@ -55,13 +56,14 @@ template<
|
||||
class Element,
|
||||
class ElementAccumulator,
|
||||
class TileShape,
|
||||
bool IsMla,
|
||||
class Mask
|
||||
>
|
||||
class Sm100FmhaBwd {
|
||||
public:
|
||||
/// Argument structure: User API
|
||||
struct Arguments {
|
||||
// Q K D HB
|
||||
// Q K D D_VO HB
|
||||
ProblemShape problem_shape;
|
||||
|
||||
const Element* ptr_Q;
|
||||
@ -98,11 +100,20 @@ public:
|
||||
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
|
||||
>;
|
||||
|
||||
using Operation = cutlass::fmha::device::FMHA<
|
||||
using OperationNormal= cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
|
||||
ProblemShape, Element, ElementAccumulator, TileShape, Mask
|
||||
>
|
||||
>;
|
||||
|
||||
using OperationMla = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
|
||||
ProblemShape, Element, ElementAccumulator, TileShape, Mask
|
||||
>
|
||||
>;
|
||||
|
||||
using Operation = std::conditional_t<IsMla, OperationMla, OperationNormal>;
|
||||
|
||||
using Kernel = typename Operation::Kernel;
|
||||
|
||||
struct Params {
|
||||
@ -121,7 +132,7 @@ private:
|
||||
ElementAccumulator* sum_odo = nullptr,
|
||||
ElementAccumulator* scaled_lse = nullptr) {
|
||||
using namespace cute;
|
||||
auto [Q_, K, D, HB] = args.problem_shape;
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
@ -141,7 +152,7 @@ private:
|
||||
|
||||
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
|
||||
using namespace cute;
|
||||
auto [Q_, K, D, HB] = args.problem_shape;
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
@ -163,6 +174,7 @@ private:
|
||||
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
|
||||
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
|
||||
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
|
||||
|
||||
return typename Operation::Arguments{
|
||||
args.problem_shape,
|
||||
{ args.ptr_Q, args.stride_Q,
|
||||
@ -207,7 +219,7 @@ public:
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
auto [Q_, K, D, HB] = args.problem_shape;
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
@ -227,7 +239,7 @@ public:
|
||||
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
|
||||
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q_, K, D, HB] = args.problem_shape;
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
@ -256,7 +268,7 @@ public:
|
||||
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q_, K, D, HB] = args.problem_shape;
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
|
||||
@ -85,11 +85,11 @@ struct FmhaKernelBwdConvert {
|
||||
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
|
||||
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(size<3,0>(params.problem_shape), size<3,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
|
||||
dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq));
|
||||
return grid;
|
||||
}
|
||||
|
||||
@ -103,7 +103,7 @@ struct FmhaKernelBwdConvert {
|
||||
}
|
||||
|
||||
template<class StrideSrc, class StrideDest, class Count>
|
||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count) {
|
||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
|
||||
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
|
||||
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
|
||||
|
||||
@ -120,7 +120,7 @@ struct FmhaKernelBwdConvert {
|
||||
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
|
||||
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
|
||||
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
ElementAcc value_src[kElementsPerLoad];
|
||||
Element value_dest[kElementsPerLoad];
|
||||
|
||||
@ -139,13 +139,13 @@ struct FmhaKernelBwdConvert {
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
if (params.ptr_src_dQ != nullptr) {
|
||||
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape));
|
||||
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));
|
||||
}
|
||||
if (params.ptr_src_dK != nullptr) {
|
||||
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape));
|
||||
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape));
|
||||
}
|
||||
if (params.ptr_src_dV != nullptr) {
|
||||
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape));
|
||||
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -86,11 +86,11 @@ struct FmhaKernelBwdSumOdO {
|
||||
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return get<2>(args.problem_shape) % kElementsPerLoad == 0;
|
||||
return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<3,0>(params.problem_shape), size<3,1>(params.problem_shape));
|
||||
dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape));
|
||||
return grid;
|
||||
}
|
||||
|
||||
@ -131,7 +131,7 @@ struct FmhaKernelBwdSumOdO {
|
||||
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
|
||||
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
|
||||
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
Element value_O[kElementsPerLoad];
|
||||
Element value_dO[kElementsPerLoad];
|
||||
|
||||
|
||||
@ -344,12 +344,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
auto [Q, K, D, HB] = args.problem_shape;
|
||||
auto [Q, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0) {
|
||||
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
|
||||
return false;
|
||||
}
|
||||
if (D % Alignment != 0) {
|
||||
if (D % Alignment != 0 || D_VO % Alignment != 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -362,7 +362,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void*) {
|
||||
auto [Q_, K_, D, HB] = args.problem_shape;
|
||||
auto [Q_, K_, D, D_VO, HB] = args.problem_shape;
|
||||
int Q = Q_;
|
||||
int K = K_;
|
||||
|
||||
@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}, /*workspace=*/nullptr);
|
||||
|
||||
auto params_vdo = CollectiveMmaVDO::to_underlying_arguments(
|
||||
make_shape(K, Q, D, HB),
|
||||
make_shape(K, Q, D_VO, HB),
|
||||
typename CollectiveMmaVDO::Arguments {
|
||||
args.mainloop.ptr_v, args.mainloop.stride_v,
|
||||
args.mainloop.ptr_do, args.mainloop.stride_do,
|
||||
@ -446,21 +446,21 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo,
|
||||
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
uint16_t mcast_mask = 0;
|
||||
|
||||
auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB));
|
||||
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB));
|
||||
auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB));
|
||||
auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB));
|
||||
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB));
|
||||
auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB));
|
||||
|
||||
auto mK = domain_offset(select<1,2,3>(blk_offset), mK_in);
|
||||
auto mV = domain_offset(select<1,2,3>(blk_offset), mV_in);
|
||||
auto mQ = domain_offset(select<0,2,3>(blk_offset), mQ_in);
|
||||
auto mDO = domain_offset(select<0,2,3>(blk_offset), mDO_in);
|
||||
auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in);
|
||||
auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in);
|
||||
auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in);
|
||||
auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in);
|
||||
|
||||
auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{});
|
||||
auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{});
|
||||
@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
// set up lse and sum_odo
|
||||
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
|
||||
|
||||
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
|
||||
auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
|
||||
@ -681,7 +681,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
|
||||
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
|
||||
auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{});
|
||||
@ -974,11 +974,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
MainloopArguments const& mainloop_args,
|
||||
EpilogueArguments const& epilogue_args) {
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
|
||||
|
||||
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
|
||||
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
|
||||
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
|
||||
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
||||
|
||||
@ -988,7 +988,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
);
|
||||
|
||||
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
|
||||
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
|
||||
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
|
||||
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
||||
|
||||
@ -1003,7 +1003,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}
|
||||
}
|
||||
for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) {
|
||||
if (elem_less(cDV(i), select<1,2>(problem_shape))) {
|
||||
if (elem_less(cDV(i), select<1,3>(problem_shape))) {
|
||||
gDV(i) = Element(0);
|
||||
}
|
||||
}
|
||||
@ -1020,8 +1020,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
|
||||
|
||||
auto load_op = SM100_TMEM_LOAD_32dp32b16x{};
|
||||
|
||||
@ -1029,7 +1029,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
tDKtDK.data() = TmemAllocation::kDK;
|
||||
|
||||
auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk);
|
||||
auto mDK = domain_offset(select<1,2,3>(blk_offset), mDK_in);
|
||||
auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in);
|
||||
auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
||||
|
||||
@ -1065,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
tDVtDV.data() = TmemAllocation::kDV;
|
||||
|
||||
auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv);
|
||||
auto mDV = domain_offset(select<1,2,3>(blk_offset), mDV_in);
|
||||
auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in);
|
||||
auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||
(_, _, blk_coord_k, _0{}, blk_coord_batch);
|
||||
|
||||
@ -1088,7 +1088,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV);
|
||||
|
||||
// store tDVgDV
|
||||
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,2>(problem_shape));
|
||||
store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape));
|
||||
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state);
|
||||
@ -1140,7 +1140,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
||||
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
|
||||
// in tmem, S & P overlap
|
||||
// and dP and dQ overlap
|
||||
@ -1396,9 +1396,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch] = blk_coord;
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
|
||||
|
||||
// must match TileShapeDQ
|
||||
auto load_op = SM100_TMEM_LOAD_32dp32b32x{};
|
||||
@ -1676,7 +1676,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
pipeline_init_wait(size(ClusterShape{}));
|
||||
|
||||
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
||||
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
||||
auto [problem_shape, blk_offset] = apply_variable_length_offset(
|
||||
params.problem_shape,
|
||||
blk_coord
|
||||
@ -1809,7 +1809,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
auto [Q, K, D, HB] = params.problem_shape;
|
||||
auto [Q, K, D, D_VO, HB] = params.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
|
||||
return grid;
|
||||
|
||||
File diff suppressed because it is too large
@ -33,7 +33,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
using namespace cutlass::fmha::collective;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
@ -61,20 +63,20 @@ void __global__ fmha_bwd_reference_dQ_kernel(
|
||||
|
||||
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
||||
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
|
||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
||||
problem_shape_in,
|
||||
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
|
||||
make_coord(_0{}, _0{}, _0{}, _0{},idx2crd(idx_L, get<4>(problem_shape_in)))
|
||||
);
|
||||
// problem_shape = problem_shape_in;
|
||||
// offset = repeat_like(problem_shape_in, _0{});
|
||||
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
|
||||
auto mDQ = domain_offset(select<0,2,3>(offset), mDQ_in);
|
||||
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
|
||||
auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in);
|
||||
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
|
||||
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
|
||||
ElementAccumulator acc_qk = 0;
|
||||
@ -82,10 +84,15 @@ void __global__ fmha_bwd_reference_dQ_kernel(
|
||||
ElementAccumulator acc_doo = 0;
|
||||
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
|
||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||
} // for idx_D0
|
||||
|
||||
for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
|
||||
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
|
||||
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
|
||||
}
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
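For orientation, the fmha_bwd reference kernels (dQ above, dK and dV below) accumulate the standard attention backward quantities, now with the QK head dimension (D) and the V/output head dimension (D_VO) tracked separately in the rank-5 problem shape — hence the split idx_D0/idx_D1 loops. The following is a dense NumPy sketch of those quantities, not the CUDA reference itself: it assumes a single head, no masking, and recomputes the softmax directly instead of using the stored LSE.

import numpy as np

def attention_bwd_reference(q, k, v, do, scale):
    # q: (seq_q, d_qk), k: (seq_k, d_qk), v: (seq_k, d_vo), do: (seq_q, d_vo)
    s = scale * q @ k.T                           # scaled acc_qk
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)            # softmax(S)
    o = p @ v                                     # forward output
    dv = p.T @ do                                 # dV accumulation
    dp = do @ v.T                                 # acc_dov
    d_row = (do * o).sum(axis=-1, keepdims=True)  # acc_doo ("sum_odo")
    ds = p * (dp - d_row)
    dq = scale * ds @ k                           # dQ reference result
    dk = scale * ds.T @ q                         # dK reference result
    return o, dq, dk, dv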
@ -135,20 +142,20 @@ void __global__ fmha_bwd_reference_dK_kernel(
|
||||
|
||||
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
||||
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
|
||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
||||
problem_shape_in,
|
||||
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
|
||||
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
|
||||
);
|
||||
// problem_shape = problem_shape_in;
|
||||
// offset = repeat_like(problem_shape_in, _0{});
|
||||
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
|
||||
auto mDK = domain_offset(select<1,2,3>(offset), mDK_in);
|
||||
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
|
||||
auto mDK = domain_offset(select<1,2,4>(offset), mDK_in);
|
||||
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
|
||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
|
||||
ElementAccumulator acc_qk = 0;
|
||||
@ -156,10 +163,14 @@ void __global__ fmha_bwd_reference_dK_kernel(
|
||||
ElementAccumulator acc_doo = 0;
|
||||
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
|
||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||
} // for idx_D0
|
||||
|
||||
|
||||
for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
|
||||
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
|
||||
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
|
||||
}
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
@ -209,20 +220,20 @@ void __global__ fmha_bwd_reference_dV_kernel(
|
||||
|
||||
ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in)));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
||||
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
|
||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
||||
problem_shape_in,
|
||||
make_coord(_0{}, _0{}, _0{}, idx2crd(idx_L, get<3>(problem_shape_in)))
|
||||
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
|
||||
);
|
||||
// problem_shape = problem_shape_in;
|
||||
// offset = repeat_like(problem_shape_in, _0{});
|
||||
auto mQ = domain_offset(select<0,2,3>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,3>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,2,3>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,2,3>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,3>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,2,3>(offset), mDO_in);
|
||||
auto mDV = domain_offset(select<1,2,3>(offset), mDV_in);
|
||||
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
|
||||
auto mDV = domain_offset(select<1,3,4>(offset), mDV_in);
|
||||
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
|
||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
|
||||
ElementAcc acc_qk = 0;
|
||||
@ -244,7 +255,7 @@ void __global__ fmha_bwd_reference_dV_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
||||
for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) {
|
||||
ElementAcc acc = 0;
|
||||
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
|
||||
ElementAcc rS = static_cast<Element>(mS[idx_Q]);
|
||||
|
||||
@ -62,19 +62,20 @@ void __global__ fmha_reference_kernel(
|
||||
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mQ)));
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
|
||||
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) {
|
||||
|
||||
auto coord_L = idx2crd(idx_L, shape<3>(problem_shape_in));
|
||||
auto coord_L = idx2crd(idx_L, shape<4>(problem_shape_in));
|
||||
auto get_coord_in = [&]() {
|
||||
if constexpr (rank_v<decltype(get<2>(ProblemShapeIn{}))> == 2) {
|
||||
return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), coord_L);
|
||||
return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), cute::make_tuple(_0{}, _0{}), coord_L);
|
||||
} else {
|
||||
return cute::make_tuple(idx_Q, _0{}, _0{}, coord_L);
|
||||
return cute::make_tuple(idx_Q, _0{}, _0{}, _0{}, coord_L);
|
||||
}
|
||||
};
|
||||
auto coord_in = get_coord_in();
|
||||
auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<3,1>(coord_in));
|
||||
auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<4,1>(coord_in));
|
||||
|
||||
int head_qk = 0;
|
||||
int head_v = 0;
|
||||
@ -83,7 +84,7 @@ void __global__ fmha_reference_kernel(
|
||||
head_qk = size<2, 0>(problem_shape) + size<2, 1>(problem_shape);
|
||||
head_v = size<2, 0>(problem_shape);
|
||||
} else {
|
||||
head_qk = size<2>(problem_shape);
|
||||
head_qk = size<3>(problem_shape);
|
||||
head_v = head_qk;
|
||||
}
|
||||
|
||||
@ -157,6 +158,7 @@ void __global__ fmha_reference_kernel(
|
||||
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
|
||||
}
|
||||
|
||||
|
||||
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
|
||||
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
|
||||
}
|
||||
|
||||
@ -835,7 +835,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::cout << " Verfication is turned off for this run." << std::endl;
|
||||
std::cout << " Verification is turned off for this run." << std::endl;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
|
||||
@ -259,7 +259,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
@ -394,7 +394,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yeilds the CTA-local work.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
|
||||
@ -295,7 +295,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
@ -433,7 +433,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yeilds the CTA-local work.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
|
||||
@ -333,7 +333,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
@ -471,7 +471,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yeilds the CTA-local work.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
|
||||
@ -328,7 +328,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
@ -473,7 +473,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yeilds the CTA-local work.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
|
||||
@ -341,7 +341,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
|
||||
|
||||
// Step 2: The Mainloop.
|
||||
|
||||
// Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
|
||||
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
|
||||
@ -527,7 +527,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
|
||||
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
|
||||
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
|
||||
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
|
||||
// The MMA's partitioning then yeilds the CTA-local work.
|
||||
// The MMA's partitioning then yields the CTA-local work.
|
||||
|
||||
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
|
||||
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
|
||||
|
||||
@ -200,7 +200,7 @@ int main(int argc, char** argv)
|
||||
|
||||
// Construct tiled copy, a tiling of copy atoms.
|
||||
//
|
||||
// Note, this assumes the vector and thread layouts are aligned with contigous data
|
||||
// Note, this assumes the vector and thread layouts are aligned with contiguous data
|
||||
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
|
||||
// reads. Alternative value layouts are also possible, though incompatible layouts
|
||||
// will result in compile time errors.
|
||||
|
||||
@ -90,18 +90,17 @@ If you already know the TV layout you want to use for your tiled copy, CuTe DSL
# Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN))
gA = cute.zipped_divide(mA, tiler_mn)

where `tiler_mn` is the tile size per thread block and `tv_layout` is the TV layout which maps
thread index and inter-thread index of data array per thread to logical coordinates of elements in
input and output tensors.

Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy` utility.
Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy_tv` utility, which
infers the tiler and tv layout for the tiled copy automatically, where `tiler` is the tile size per thread
block and `tv_layout` is the TV layout which maps thread index and inter-thread index of data array per
thread to logical coordinates of elements in input and output tensors.

.. code-block:: python

blkA = gA[((None, None), bidx)] # (TileM,TileN)

copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)

# get slice of tiled_copy_A for current thread
thr_copy_A = tiled_copy_A.get_slice(tidx)
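To make the API change above concrete, here is a minimal sketch of how the two forms relate. The layout shapes and element type are illustrative assumptions, and the calls are meant to appear inside a CuTe DSL kernel body like `elementwise_add_kernel` above, where `tidx` comes from `cute.arch.thread_idx()`.

import cutlass
import cutlass.cute as cute

def make_input_tiled_copy(elem_type, thr_layout: cute.Layout, val_layout: cute.Layout):
    copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_type)
    # Old form: the caller derived the TV layout and tiler explicitly:
    #   tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    #   return cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn)
    # New form: the tiler and TV layout are inferred from the thread/value layouts:
    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)

# Usage inside the kernel body (hypothetical 32x4 thread and 4x1 value layouts):
#   thr_layout = cute.make_layout((32, 4))
#   val_layout = cute.make_layout((4, 1))
#   tiled_copy_A = make_input_tiled_copy(gA.element_type, thr_layout, val_layout)
#   thr_copy_A = tiled_copy_A.get_slice(tidx)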
@ -140,8 +139,8 @@ def elementwise_add_kernel(
|
||||
gC: cute.Tensor,
|
||||
cC: cute.Tensor, # coordinate tensor
|
||||
shape: cute.Shape,
|
||||
tv_layout: cute.Layout,
|
||||
tiler_mn: cute.Shape,
|
||||
thr_layout: cute.Layout,
|
||||
val_layout: cute.Layout,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
@ -165,9 +164,9 @@ def elementwise_add_kernel(
|
||||
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
|
||||
copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)
|
||||
|
||||
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
tiled_copy_B = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
tiled_copy_C = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
|
||||
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
|
||||
tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
|
||||
tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
@ -254,7 +253,7 @@ def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128):
|
||||
cC = cute.zipped_divide(idC, tiler=tiler_mn)
|
||||
print(f"[DSL INFO] coord tensor = {cC.type}")
|
||||
|
||||
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, tv_layout, tiler_mn).launch(
|
||||
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(
|
||||
grid=[cute.size(gC, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
@ -362,7 +361,7 @@ def run_elementwise_add(
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=10,
|
||||
warmup_iterations=warmup_iterations,
|
||||
profiling_iterations=iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
|
||||
@ -353,7 +353,7 @@ def run_elementwise_apply_and_verify(
|
||||
current_stream,
|
||||
),
|
||||
warmup_iterations=warmup_iterations,
|
||||
profiling_iterations=iterations,
|
||||
iterations=iterations,
|
||||
use_cuda_graphs=True,
|
||||
stream=current_stream,
|
||||
)
|
||||
|
||||
@ -32,13 +32,13 @@ from typing import Type, Union, Callable
|
||||
|
||||
import torch
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import cpasync, warp
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.utils.ampere_helpers as sm80_utils
|
||||
import cutlass.utils as utils
|
||||
|
||||
"""
|
||||
A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL.
|
||||
@ -163,7 +163,7 @@ class FlashAttentionForwardAmpere:
|
||||
# Check if block size setting is out of shared memory capacity
|
||||
# Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size
|
||||
smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2
|
||||
smem_capacity = sm80_utils.SMEM_CAPACITY["sm80"]
|
||||
smem_capacity = utils.get_smem_capacity_in_bytes("sm_80")
|
||||
if smem_usage > smem_capacity:
|
||||
return False
|
||||
|
||||
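As a quick sanity check on the shared-memory formula above, here is the arithmetic for one hypothetical configuration. The block sizes are chosen for illustration only, the trailing factor of 2 presumably converts elements to bytes for a 16-bit dtype, and the SM80 capacity figure is approximate.

# Worked example of the check above with hypothetical block sizes, assuming a
# 2-byte element type (the trailing "* 2" converts elements to bytes).
m_block_size, n_block_size, head_dim = 128, 64, 64

smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2
print(smem_usage)  # 32768 bytes (32 KiB): 16 KiB for the Q tile + 16 KiB for K/V tiles

# This fits well under the SM80 limit returned by
# utils.get_smem_capacity_in_bytes("sm_80") (roughly the 160 KiB class), so
# can_implement would not reject this configuration on shared-memory grounds.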
@ -469,21 +469,9 @@ class FlashAttentionForwardAmpere:
|
||||
warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
|
||||
self._dtype,
|
||||
)
|
||||
smem_tiled_copy_Q = cute.make_tiled_copy(
|
||||
smem_copy_atom_Q,
|
||||
layout_tv=tiled_mma.tv_layout_A_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
smem_tiled_copy_K = cute.make_tiled_copy(
|
||||
smem_copy_atom_K,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
smem_tiled_copy_V = cute.make_tiled_copy(
|
||||
smem_copy_atom_V,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma)
|
||||
smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma)
|
||||
smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma)
|
||||
|
||||
smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx)
|
||||
smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx)
|
||||
@ -702,11 +690,7 @@ class FlashAttentionForwardAmpere:
|
||||
cute.nvgpu.CopyUniversalOp(), self._dtype
|
||||
)
|
||||
# tiled copy atom for O
|
||||
smem_tiled_copy_O = cute.make_tiled_copy(
|
||||
smem_copy_atom_O,
|
||||
layout_tv=tiled_mma.tv_layout_C_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)),
|
||||
)
|
||||
smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma)
|
||||
smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx)
|
||||
taccOrO = smem_thr_copy_O.retile(rO)
|
||||
taccOsO = smem_thr_copy_O.partition_D(sO)
|
||||
@ -1178,7 +1162,7 @@ class FlashAttentionForwardAmpere:
|
||||
return cute.arch.exp2(x)
|
||||
|
||||
|
||||
def run_flash_attention_fwd(
|
||||
def run(
|
||||
dtype: Type[cutlass.Numeric],
|
||||
batch_size: int,
|
||||
seqlen_q: int,
|
||||
@ -1193,6 +1177,8 @@ def run_flash_attention_fwd(
|
||||
warmup_iterations: int = 0,
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Skip unsupported testcase
|
||||
if not FlashAttentionForwardAmpere.can_implement(
|
||||
@ -1207,6 +1193,23 @@ def run_flash_attention_fwd(
|
||||
f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}"
|
||||
)
|
||||
|
||||
print(f"Running Ampere SM80 FlashAttentionForward test with:")
|
||||
print(f" dtype: {dtype}")
|
||||
print(f" batch_size: {batch_size}")
|
||||
print(f" seqlen_q: {seqlen_q}")
|
||||
print(f" seqlen_k: {seqlen_k}")
|
||||
print(f" num_head: {num_head}")
|
||||
print(f" head_dim: {head_dim}")
|
||||
print(f" softmax_scale: {softmax_scale}")
|
||||
print(f" m_block_size: {m_block_size}")
|
||||
print(f" n_block_size: {n_block_size}")
|
||||
print(f" num_threads: {num_threads}")
|
||||
print(f" is_causal: {is_causal}")
|
||||
print(f" warmup_iterations: {warmup_iterations}")
|
||||
print(f" iterations: {iterations}")
|
||||
print(f" skip_ref_check: {skip_ref_check}")
|
||||
print(f" use_cold_l2: {use_cold_l2}")
|
||||
|
||||
# Create tensor Q/K/V/O
|
||||
def create_tensor(
|
||||
batch_size: int,
|
||||
@ -1217,22 +1220,28 @@ def run_flash_attention_fwd(
|
||||
) -> cute.Tensor:
|
||||
# (batch_size, seqlen, num_head, head_dim)
|
||||
shape = (batch_size, seqlen, num_head, head_dim)
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32).random_(-2, 2).to(dtype=dtype).cuda()
|
||||
torch_tensor = (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=cutlass_torch.dtype(dtype))
|
||||
.cuda()
|
||||
)
|
||||
# assume input is 16B aligned.
|
||||
cute_tensor = (
|
||||
from_dlpack(torch_tensor, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3,
|
||||
stride_order=torch_tensor.dim_order(),
|
||||
divisibility=(128 // dtype.width),
|
||||
)
|
||||
)
|
||||
return cute_tensor, torch_tensor
|
||||
|
||||
q = create_tensor(
|
||||
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
k = create_tensor(
|
||||
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
v = create_tensor(
|
||||
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
o = create_tensor(
|
||||
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
|
||||
)
|
||||
q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
|
||||
fa2_fwd = FlashAttentionForwardAmpere(
|
||||
head_dim,
|
||||
@ -1241,78 +1250,63 @@ def run_flash_attention_fwd(
|
||||
num_threads,
|
||||
is_causal,
|
||||
)
|
||||
# assume input is 16B align.
|
||||
q_tensor = (
|
||||
from_dlpack(q, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=q.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
k_tensor = (
|
||||
from_dlpack(k, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=k.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
v_tensor = (
|
||||
from_dlpack(v, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=v.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
o_tensor = (
|
||||
from_dlpack(o, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=3)
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=3, stride_order=o.dim_order(), divisibility=(128 // dtype.width)
|
||||
)
|
||||
)
|
||||
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
# compile the fa2 forward pass
|
||||
compiled_fa2_fwd = cute.compile(
|
||||
fa2_fwd, q_tensor, k_tensor, v_tensor, o_tensor, softmax_scale, current_stream
|
||||
compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream)
|
||||
|
||||
if not skip_ref_check:
|
||||
compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream)
|
||||
torch.cuda.synchronize()
|
||||
q_ref = q_torch.permute(0, 2, 1, 3)
|
||||
k_ref = k_torch.permute(0, 2, 1, 3)
|
||||
v_ref = v_torch.permute(0, 2, 1, 3)
|
||||
torch.backends.cuda.enable_flash_sdp(enabled=True)
|
||||
ref_o = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
|
||||
).permute(0, 2, 1, 3)
|
||||
torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
|
||||
o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
|
||||
return testing.JitArguments(
|
||||
q_workspace,
|
||||
k_workspace,
|
||||
v_workspace,
|
||||
o_workspace,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
q_torch.numel() * q_torch.element_size()
|
||||
+ k_torch.numel() * k_torch.element_size()
|
||||
+ v_torch.numel() * v_torch.element_size()
|
||||
+ o_torch.numel() * o_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
compiled_fa2_fwd,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
# warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_fa2_fwd(
|
||||
q_tensor,
|
||||
k_tensor,
|
||||
v_tensor,
|
||||
o_tensor,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
)
|
||||
# run the compiled fa2 forward pass
|
||||
for _ in range(iterations):
|
||||
compiled_fa2_fwd(
|
||||
q_tensor,
|
||||
k_tensor,
|
||||
v_tensor,
|
||||
o_tensor,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if skip_ref_check:
|
||||
return
|
||||
# reference implementation
|
||||
q_ref = q.permute(0, 2, 1, 3)
|
||||
k_ref = k.permute(0, 2, 1, 3)
|
||||
v_ref = v.permute(0, 2, 1, 3)
|
||||
torch.backends.cuda.enable_flash_sdp(enabled=True)
|
||||
ref_o = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
|
||||
).permute(0, 2, 1, 3)
|
||||
|
||||
torch.testing.assert_close(o.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
|
||||
|
||||
return avg_time_us # Return execution time in microseconds
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -1334,9 +1328,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference check"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
run_flash_attention_fwd(
|
||||
run(
|
||||
args.dtype,
|
||||
args.batch_size,
|
||||
args.seqlen_q,
|
||||
@ -1348,6 +1348,10 @@ if __name__ == "__main__":
|
||||
args.n_block_size,
|
||||
args.num_threads,
|
||||
args.is_causal,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
|
||||
print("PASS")
|
||||
|
||||
@ -634,16 +634,50 @@ class SGemm:
|
||||
return
|
||||
|
||||
|
||||
def main(
|
||||
def run(
|
||||
mnk: Tuple[int, int, int],
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
problem_shape: Tuple[int, int, int],
|
||||
static_shape: bool = False,
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
M, N, K = problem_shape
|
||||
"""Execute SIMT GEMM operation and benchmark performance.
|
||||
|
||||
:param mnk: GEMM problem size (M, N, K, L)
|
||||
:type mnk: Tuple[int, int, int, int]
|
||||
:param a_major: Memory layout of tensor A
|
||||
:type a_major: str
|
||||
:param b_major: Memory layout of tensor B
|
||||
:type b_major: str
|
||||
:param c_major: Memory layout of tensor C
|
||||
:type c_major: str
|
||||
:param static_shape: Whether to use static shape optimization, defaults to False
|
||||
:type static_shape: bool, optional
|
||||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 2
|
||||
:type warmup_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations to run, defaults to 100
|
||||
:type iterations: int, optional
|
||||
:param skip_ref_check: Skip validation against reference implementation, defaults to False
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:return: Execution time of the GEMM kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
print(f"Running Ampere SIMT GEMM example:")
|
||||
print(f"mnk: {mnk}")
|
||||
print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}")
|
||||
print(f"Static shape: {static_shape}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {use_cold_l2}")
|
||||
M, N, K = mnk
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
|
||||
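A hypothetical call to the renamed entry point might look like the following. The problem size and major modes are illustrative, and keyword arguments are used so that only the parameter names shown above, not their positional order, are relied on.

# Hypothetical invocation of run() from within this module (values are illustrative).
avg_time_us = run(
    mnk=(1024, 1024, 1024),
    a_major="k",
    b_major="k",
    c_major="n",
    static_shape=False,
    warmup_iterations=2,
    iterations=100,
    skip_ref_check=False,
    use_cold_l2=False,
)
print(f"avg kernel time: {avg_time_us / 1e3:.4f} ms")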
@ -710,20 +744,6 @@ def main(
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
kernel_arguments=testing.JitArguments(
|
||||
a_tensor, b_tensor, c_tensor, current_stream
|
||||
),
|
||||
warmup_iterations=warmup_iterations,
|
||||
profiling_iterations=iterations,
|
||||
use_cuda_graphs=False,
|
||||
stream=current_stream,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
|
||||
if not skip_ref_check:
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
torch.cuda.synchronize()
|
||||
@ -732,6 +752,71 @@ def main(
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
# Create new tensors for each workspace to ensure cold L2 cache
|
||||
a_workspace = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
|
||||
b_workspace = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
|
||||
c_workspace = create_and_permute_tensor(M, N, c_major == "m", torch.float32)
|
||||
|
||||
if static_shape:
|
||||
a_tensor_workspace = (
|
||||
from_dlpack(a_workspace, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
divisibility=divisibility_a,
|
||||
)
|
||||
)
|
||||
else:
|
||||
a_tensor_workspace = from_dlpack(a_workspace, assumed_align=16)
|
||||
|
||||
b_tensor_workspace = (
|
||||
from_dlpack(b_workspace, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
divisibility=divisibility_b,
|
||||
)
|
||||
)
|
||||
|
||||
c_tensor_workspace = (
|
||||
from_dlpack(c_workspace, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
divisibility=divisibility_c,
|
||||
)
|
||||
)
|
||||
|
||||
return testing.JitArguments(
|
||||
a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, current_stream
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a.numel() * a.element_size()
|
||||
+ b.numel() * b.element_size()
|
||||
+ c.numel() * c.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
|
||||
return avg_time_us # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -753,19 +838,27 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running SIMT GEMM example:")
|
||||
|
||||
torch.manual_seed(1024)
|
||||
|
||||
main(
|
||||
run(
|
||||
args.mnk,
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.mnk,
|
||||
args.static_shape,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -51,7 +51,7 @@ This GEMM kernel supports the following features:
|
||||
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
|
||||
- Threadblock rasterization to improve data re-use
|
||||
- Supports multi-stage pipeline to overlap computation and memory access
|
||||
- Implements shared memory buffering for epilogue to increase coalesed global memory access
|
||||
- Implements shared memory buffering for epilogue to increase coalesced global memory access
|
||||
|
||||
This GEMM works as follows:
|
||||
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
|
||||
@ -214,7 +214,7 @@ class TensorOpGemm:
|
||||
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
|
||||
)
|
||||
|
||||
# Creates a synchonous copy atom and thread layouts for the epilogue
|
||||
# Creates a synchronous copy atom and thread layouts for the epilogue
|
||||
c_copy_bits = 128
|
||||
atom_sync_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
@ -550,16 +550,8 @@ class TensorOpGemm:
|
||||
|
||||
# Creates the tiled copy so that it matches the thread-value layout
|
||||
# expected by the tiled mma
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy(
|
||||
atom_copy_s2r_A,
|
||||
layout_tv=tiled_mma.tv_layout_A_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy(
|
||||
atom_copy_s2r_B,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
|
||||
|
||||
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
|
||||
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
|
||||
@ -836,8 +828,7 @@ class TensorOpGemm:
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
|
||||
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
|
||||
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
|
||||
|
||||
def raster_tile(self, i, j, f):
|
||||
new_i = i // f
|
||||
@ -845,20 +836,33 @@ class TensorOpGemm:
|
||||
return (new_i, new_j)
|
||||
|
||||
|
||||
def run_tensor_op_gemm(
|
||||
def run(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
problem_shape: Tuple[int, int, int, int],
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
M, N, K, L = problem_shape
|
||||
print(f"Running Ampere tensor core GEMM example:")
|
||||
print(f"mnkl: {mnkl}")
|
||||
print(
|
||||
f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
|
||||
)
|
||||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||||
print(f"Atoms layout: {atom_layout_mnk}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {use_cold_l2}")
|
||||
M, N, K, L = mnkl
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
|
||||
@ -866,23 +870,28 @@ def run_tensor_op_gemm(
|
||||
# else: (l, mode0, mode1) -> (mode0, mode1, l)
|
||||
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
|
||||
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
||||
|
||||
return (
|
||||
torch_tensor = (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=dtype)
|
||||
.to(dtype=cutlass_torch.dtype(dtype))
|
||||
.permute(permute_order)
|
||||
.cuda()
|
||||
)
|
||||
# assume input is 16B aligned
|
||||
cute_tensor = (
|
||||
from_dlpack(torch_tensor, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if not is_mode0_major else 0),
|
||||
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
|
||||
divisibility=(128 // dtype.width),
|
||||
)
|
||||
)
|
||||
return cute_tensor, torch_tensor
|
||||
|
||||
a = create_and_permute_tensor(
|
||||
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
b = create_and_permute_tensor(
|
||||
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
|
||||
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
|
||||
mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
|
||||
mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
|
||||
mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
|
||||
|
||||
tensor_op_gemm = TensorOpGemm(
|
||||
ab_dtype,
|
||||
@ -891,56 +900,49 @@ def run_tensor_op_gemm(
|
||||
atom_layout_mnk,
|
||||
)
|
||||
|
||||
# assume input is 16B aligned
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
|
||||
divisibility=(128 // c_dtype.width),
|
||||
)
|
||||
)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
|
||||
compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC)
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
if not skip_ref_check:
|
||||
ref = torch.einsum(
|
||||
"mkl,nkl->mnl",
|
||||
a_torch.to(dtype=torch.float32),
|
||||
b_torch.to(dtype=torch.float32),
|
||||
).to(cutlass_torch.dtype(c_dtype))
|
||||
compiled_gemm(mA, mB, mC)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
|
||||
b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
|
||||
c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
|
||||
return testing.JitArguments(a_workspace, b_workspace, c_workspace)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch.numel() * a_torch.element_size()
|
||||
+ b_torch.numel() * b_torch.element_size()
|
||||
+ c_torch.numel() * c_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
gemm,
|
||||
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
warmup_iterations=warmup_iterations,
|
||||
profiling_iterations=iterations,
|
||||
iterations=iterations,
|
||||
use_cuda_graphs=False,
|
||||
)
|
||||
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
|
||||
if not skip_ref_check:
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
return avg_time_us # Return execution time in microseconds
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -985,10 +987,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running Ampere tensor core GEMM example:")
|
||||
run_tensor_op_gemm(
|
||||
run(
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
@ -1000,5 +1007,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
File diff suppressed because it is too large
@ -212,7 +212,7 @@ class DenseGemmKernel:
|
||||
|
||||
self.occupancy = 1
|
||||
self.threads_per_cta = 128
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1106,11 +1106,7 @@ class DenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1772,7 +1768,7 @@ def run_dense_gemm(
|
||||
ref_c = ref
|
||||
elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
|
||||
# m major: (l, n, m) -> (m, n, l)
|
||||
# k major: (l, m, n) -> (m, n, l)
|
||||
# n major: (l, m, n) -> (m, n, l)
|
||||
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
|
||||
shape = (l, m, n) if c_major == "n" else (l, n, m)
|
||||
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
|
||||
|
||||
@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@ -226,7 +227,7 @@ class PersistentDenseGemmKernel:
|
||||
self.cta_sync_bar_id = 0
|
||||
self.epilog_sync_bar_id = 1
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel:
|
||||
return can_implement
|
||||
|
||||
|
||||
def run_dense_gemm(
|
||||
def run(
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
@ -1832,17 +1829,58 @@ def run_dense_gemm(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
use_2cta_instrs: bool,
|
||||
use_tma_store: bool,
|
||||
tolerance: float,
|
||||
mma_tiler_mn: Tuple[int, int] = (256, 256),
|
||||
cluster_shape_mn: Tuple[int, int] = (2, 1),
|
||||
use_2cta_instrs: bool = True,
|
||||
use_tma_store: bool = True,
|
||||
tolerance: float = 1e-01,
|
||||
warmup_iterations: int = 0,
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
||||
"""Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
|
||||
|
||||
This function prepares input tensors, configures and launches the persistent GEMM kernel,
|
||||
optionally performs reference validation, and benchmarks the execution performance.
|
||||
|
||||
:param mnkl: Problem size (M, N, K, L)
|
||||
:type mnkl: Tuple[int, int, int, int]
|
||||
:param ab_dtype: Data type for input tensors A and B
|
||||
:type ab_dtype: Type[cutlass.Numeric]
|
||||
:param c_dtype: Data type for output tensor C
|
||||
:type c_dtype: Type[cutlass.Numeric]
|
||||
:param acc_dtype: Data type for accumulation during matrix multiplication
|
||||
:type acc_dtype: Type[cutlass.Numeric]
|
||||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||||
:type a_major/b_major/c_major: str
|
||||
:param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type mma_tiler_mn: Tuple[int, int], optional
|
||||
:param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the
|
||||
default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type cluster_shape_mn: Tuple[int, int], optional
|
||||
:param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner
|
||||
will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_2cta_instrs: bool, optional
|
||||
:param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use
|
||||
the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters.
|
||||
:type use_tma_store: bool, optional
|
||||
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
|
||||
:type tolerance: float, optional
|
||||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
||||
:type warmup_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations to run, defaults to 1
|
||||
:type iterations: int, optional
|
||||
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:raises RuntimeError: If CUDA GPU is not available
|
||||
:raises ValueError: If the configuration is invalid or unsupported by the kernel
|
||||
:return: Execution time of the GEMM kernel
|
||||
:rtype: float
|
||||
"""
|
||||
print(f"Running Blackwell Persistent Dense GEMM test with:")
|
||||
print(f"mnkl: {mnkl}")
|
||||
@ -1855,6 +1893,7 @@ def run_dense_gemm(
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Unpack parameters
|
||||
m, n, k, l = mnkl
|
||||
@ -1931,15 +1970,15 @@ def run_dense_gemm(
|
||||
is_dynamic_layout=is_dynamic_layout,
|
||||
)
|
||||
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor
|
||||
return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
|
||||
|
||||
a_ref, a_tensor, a_torch = create_and_permute_tensor(
|
||||
a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
|
||||
l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
b_ref, b_tensor, b_torch = create_and_permute_tensor(
|
||||
b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
|
||||
l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
|
||||
)
|
||||
c_ref, c_tensor, c_torch = create_and_permute_tensor(
|
||||
c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
|
||||
l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
|
||||
)
|
||||
|
||||
@ -1967,16 +2006,8 @@ def run_dense_gemm(
|
||||
gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
|
||||
)
|
||||
|
||||
# Launch GPU kernel
|
||||
# Warm up
|
||||
for i in range(warmup_iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
# Execution
|
||||
for i in range(iterations):
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
# Compute reference result
|
||||
if not skip_ref_check:
|
||||
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
if ab_dtype in {
|
||||
cutlass.Int8,
|
||||
cutlass.Uint8,
|
||||
@ -2028,6 +2059,40 @@ def run_dense_gemm(
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
a_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
b_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
c_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
)
|
||||
return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch_cpu.numel() * a_torch_cpu.element_size()
|
||||
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
|
||||
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
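# --- Editor's sketch of the cold-L2 workspace heuristic ----------------------
# `testing.get_workspace_count` is used above but its implementation is not
# shown in this diff. The idea it is assumed to capture: allocate enough
# independent A/B/C tensor sets that consecutive launches never reuse data
# still resident in L2. A minimal stand-in (128 MiB is an assumed L2 size):
def _approx_workspace_count(one_workspace_bytes: int,
                            warmup_iterations: int,
                            iterations: int,
                            l2_cache_bytes: int = 128 * 1024 * 1024) -> int:
    # Enough workspaces to cover L2 at least once before any set is revisited.
    count = l2_cache_bytes // max(one_workspace_bytes, 1) + 1
    # Never more sets than total launches; at least one set is always needed.
    return max(1, min(count, warmup_iterations + iterations))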
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -2090,6 +2155,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -2102,7 +2173,7 @@ if __name__ == "__main__":
|
||||
if len(args.cluster_shape_mn) != 2:
|
||||
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
||||
|
||||
run_dense_gemm(
|
||||
run(
|
||||
args.mnkl,
|
||||
args.ab_dtype,
|
||||
args.c_dtype,
|
||||
@ -2118,5 +2189,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -223,7 +223,7 @@ class DenseGemmKernel:
|
||||
|
||||
self.occupancy = 1
|
||||
self.threads_per_cta = 128
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
@ -1063,11 +1063,7 @@ class DenseGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
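# Editor's note (not part of the diff): `cute.make_tiled_copy_D` is assumed to
# derive the R2S tiled copy from the destination thread-value layout of
# `tiled_copy_t2r`, i.e. it is taken as a shorthand for the removed explicit
# `make_tiled_copy(copy_atom_r2s, layout_tv=..., tiler_mn=...)` construction.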
|
||||
|
||||
@ -43,6 +43,7 @@ import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cute.typing import Int32, Int64, Float32, Boolean
|
||||
|
||||
@ -90,7 +91,7 @@ Constraints for this example:
* Number of heads in Q must be divisible by number of heads in K
* mma_tiler_mn must be 128,128
* Batch size must be the same for Q, K, and V tensors
* For causal masking, use --has_casual_mask (note: specify without =True/False)
* For causal masking, use --is_causal (note: specify without =True/False)
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
"""
|
||||
|
||||
@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
smem_copy_atom = sm100_utils.get_smem_store_op(
|
||||
self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy(
|
||||
smem_copy_atom,
|
||||
layout_tv=tiled_tmem_load.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_tmem_load.tiler_mn,
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
|
||||
|
||||
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
|
||||
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
|
||||
@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
return tile_sched_params, grid
|
||||
|
||||
|
||||
def run_fmha_and_verify(
|
||||
def run(
|
||||
q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
in_dtype: Type[cutlass.Numeric],
|
||||
@ -2628,7 +2625,7 @@ def run_fmha_and_verify(
|
||||
pv_acc_dtype: Type[cutlass.Numeric],
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
is_persistent: bool,
|
||||
has_casual_mask: bool,
|
||||
is_causal: bool,
|
||||
scale_q: float,
|
||||
scale_k: float,
|
||||
scale_v: float,
|
||||
@ -2638,6 +2635,8 @@ def run_fmha_and_verify(
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results.
|
||||
|
||||
@ -2670,8 +2669,8 @@ def run_fmha_and_verify(
|
||||
:type mma_tiler_mn: Tuple[int, int]
|
||||
:param is_persistent: Whether to use persistent kernel optimization
|
||||
:type is_persistent: bool
|
||||
:param has_casual_mask: Whether to apply causal masking
|
||||
:type has_casual_mask: bool
|
||||
:param is_causal: Whether to apply causal masking
|
||||
:type is_causal: bool
|
||||
:param scale_q: Scaling factor for query tensor
|
||||
:type scale_q: float
|
||||
:param scale_k: Scaling factor for key tensor
|
||||
@ -2690,9 +2689,13 @@ def run_fmha_and_verify(
|
||||
:type iterations: int
|
||||
:param skip_ref_check: Skip validation against reference implementation
|
||||
:type skip_ref_check: bool
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
|
||||
:type use_cold_l2: bool
|
||||
|
||||
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
|
||||
:raises RuntimeError: If GPU is unavailable for computation
|
||||
:return: Execution time of the FMHA kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
print(f"Running Blackwell SM100 FMHA test with:")
|
||||
@ -2704,13 +2707,17 @@ def run_fmha_and_verify(
|
||||
print(f" pv_acc_dtype: {pv_acc_dtype}")
|
||||
print(f" mma_tiler_mn: {mma_tiler_mn}")
|
||||
print(f" is_persistent: {is_persistent}")
|
||||
print(f" has_casual_mask: {has_casual_mask}")
|
||||
print(f" is_causal: {is_causal}")
|
||||
print(f" scale_q: {scale_q}")
|
||||
print(f" scale_k: {scale_k}")
|
||||
print(f" scale_v: {scale_v}")
|
||||
print(f" inv_scale_o: {inv_scale_o}")
|
||||
print(f" scale_softmax: {scale_softmax}")
|
||||
print(f" tolerance: {tolerance}")
|
||||
print(f" warmup_iterations: {warmup_iterations}")
|
||||
print(f" iterations: {iterations}")
|
||||
print(f" skip_ref_check: {skip_ref_check}")
|
||||
print(f" use_cold_l2: {use_cold_l2}")
|
||||
|
||||
# Unpack parameters
|
||||
b, s_q, h_q, d = q_shape
|
||||
@ -2882,7 +2889,7 @@ def run_fmha_and_verify(
|
||||
mma_tiler = (*mma_tiler_mn, d)
|
||||
|
||||
mask_type = MaskType.NO_MASK
|
||||
if has_casual_mask:
|
||||
if is_causal:
|
||||
mask_type = MaskType.CAUSAL_MASK
|
||||
else:
|
||||
if isinstance(s_k, tuple):
|
||||
@ -2942,41 +2949,7 @@ def run_fmha_and_verify(
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Execute kernel
|
||||
for _ in range(iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_torch_fmha(
|
||||
q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False
|
||||
):
|
||||
def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False):
|
||||
h_q = q.shape[2]
|
||||
h_k = k.shape[2]
|
||||
|
||||
@ -3005,7 +2978,7 @@ def run_fmha_and_verify(
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# For the situation that torch has not supported, we need to handle it manually
|
||||
situation1 = has_casual_mask and (q.is_nested or k.is_nested)
|
||||
situation1 = is_causal and (q.is_nested or k.is_nested)
|
||||
situation2 = (q.is_nested and not k.is_nested) or (
|
||||
not q.is_nested and k.is_nested
|
||||
)
|
||||
@ -3025,8 +2998,9 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref_i = ref_i.transpose(0, 1) * scale_output
|
||||
ref_list.append(ref_i)
|
||||
if q.is_nested:
|
||||
ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged)
|
||||
@ -3040,15 +3014,28 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
return ref
|
||||
|
||||
if not skip_ref_check:
|
||||
# Execute kernel once for reference checking
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
print("Verifying results...")
|
||||
o_ref = run_torch_fmha(
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal
|
||||
)
|
||||
|
||||
if o_ref.is_nested:
|
||||
@ -3095,6 +3082,76 @@ def run_fmha_and_verify(
|
||||
torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
_, q_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, k_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, v_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, o_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
out_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
return testing.JitArguments(
|
||||
q_tensor_workspace.iterator,
|
||||
k_tensor_workspace.iterator,
|
||||
v_tensor_workspace.iterator,
|
||||
o_tensor_workspace.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
|
||||
k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
|
||||
v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
|
||||
o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
|
||||
one_workspace_bytes = (
|
||||
q_torch_effective.numel() * q_torch_effective.element_size()
|
||||
+ k_torch_effective.numel() * k_torch_effective.element_size()
|
||||
+ v_torch_effective.numel() * v_torch_effective.element_size()
|
||||
+ o_torch_effective.numel() * o_torch_effective.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_fmha,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
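# --- Editor's sketch: interpreting the returned timing ------------------------
# Illustrative helper (not part of the diff) for converting the microsecond
# result of `testing.benchmark` into an approximate achieved TFLOP/s figure,
# using the standard 4*b*h*s_q*s_k*d FLOP count for an attention forward pass
# (roughly halved under causal masking).
def _approx_fmha_tflops(exec_time_us: float, b: int, h_q: int,
                        s_q: int, s_k: int, d: int,
                        is_causal: bool = False) -> float:
    flops = 4.0 * b * h_q * s_q * s_k * d
    if is_causal:
        flops *= 0.5
    return flops / (exec_time_us * 1e-6) / 1e12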
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str):
|
||||
@ -3185,7 +3242,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--has_casual_mask",
|
||||
"--is_causal",
|
||||
action="store_true",
|
||||
help="Whether to use casual mask",
|
||||
)
|
||||
@ -3263,6 +3320,13 @@ if __name__ == "__main__":
|
||||
help="Skip reference check",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if len(args.q_shape) != 4:
|
||||
@ -3279,7 +3343,7 @@ if __name__ == "__main__":
|
||||
|
||||
torch.manual_seed(1111)
|
||||
|
||||
run_fmha_and_verify(
|
||||
run(
|
||||
args.q_shape,
|
||||
args.k_shape,
|
||||
args.in_dtype,
|
||||
@ -3288,7 +3352,7 @@ if __name__ == "__main__":
|
||||
args.pv_acc_dtype,
|
||||
args.mma_tiler_mn,
|
||||
args.is_persistent,
|
||||
args.has_casual_mask,
|
||||
args.is_causal,
|
||||
args.scale_q,
|
||||
args.scale_k,
|
||||
args.scale_v,
|
||||
@ -3298,6 +3362,7 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
|
||||
print("PASS")
|
||||
|
||||
@ -36,6 +36,7 @@ import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
@ -157,7 +158,7 @@ class GroupedGemmKernel:
|
||||
self.tmem_ptr_sync_bar_id = 2
|
||||
# Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion
|
||||
self.tensormap_ab_init_bar_id = 4
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
self.num_tma_load_bytes = 0
|
||||
|
||||
def _setup_attributes(self):
|
||||
@ -951,7 +952,7 @@ class GroupedGemmKernel:
|
||||
# Specialized MMA warp
|
||||
#
|
||||
if warp_idx == self.mma_warp_id:
|
||||
# initilize tensormap A, B for TMA warp
|
||||
# initialize tensormap A, B for TMA warp
|
||||
if cutlass.const_expr(self.delegate_tensormap_ab_init):
|
||||
tensormap_manager.init_tensormap_from_atom(
|
||||
tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id
|
||||
@ -1540,11 +1541,7 @@ class GroupedGemmKernel:
|
||||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy(
|
||||
copy_atom_r2s,
|
||||
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_copy_t2r.tiler_mn,
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
@ -1815,7 +1812,136 @@ class GroupedGemmKernel:
|
||||
tensor_memory_management_bytes = 12
|
||||
|
||||
|
||||
def run_grouped_gemm(
|
||||
# Create tensor and return the pointer, tensor, and stride
|
||||
def create_tensor_and_stride(
|
||||
l: int,
|
||||
mode0: int,
|
||||
mode1: int,
|
||||
is_mode0_major: bool,
|
||||
dtype: type[cutlass.Numeric],
|
||||
is_dynamic_layout: bool = True,
|
||||
torch_tensor_cpu: torch.Tensor = None,
|
||||
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
|
||||
"""Create a GPU tensor from scratch or based on an existing CPU tensor.
|
||||
|
||||
:param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one.
|
||||
:type torch_tensor_cpu: torch.Tensor, optional
|
||||
"""
|
||||
if torch_tensor_cpu is None:
|
||||
# Create new CPU tensor
|
||||
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
|
||||
|
||||
# Create GPU tensor from CPU tensor (new or existing)
|
||||
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
|
||||
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
|
||||
)
|
||||
return (
|
||||
torch_tensor.data_ptr(),
|
||||
torch_tensor,
|
||||
cute_tensor,
|
||||
torch_tensor_cpu,
|
||||
torch_tensor.stride()[:-1],
|
||||
)
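# --- Illustrative usage (editor's sketch, not part of the diff) ---------------
# Shows how the new optional `torch_tensor_cpu` argument lets one host buffer
# back several independent GPU workspaces, which is what the cold-L2
# benchmarking path below relies on. Shapes and dtypes are arbitrary.
def _example_workspace_reuse():
    ptr0, gpu0, cute0, cpu, stride = create_tensor_and_stride(
        1, 1024, 512, False, cutlass.Float16
    )
    # Reusing `cpu` allocates only a fresh GPU copy with identical contents.
    ptr1, gpu1, cute1, _, _ = create_tensor_and_stride(
        1, 1024, 512, False, cutlass.Float16, torch_tensor_cpu=cpu
    )
    assert ptr0 != ptr1  # distinct device allocations, same host data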
|
||||
|
||||
|
||||
def create_tensors_for_all_groups(
|
||||
problem_sizes_mnkl: List[tuple[int, int, int, int]],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
torch_fp32_tensors_abc: List[List[torch.Tensor]] = None,
|
||||
) -> tuple[
|
||||
List[List[int]],
|
||||
List[List[torch.Tensor]],
|
||||
List[tuple],
|
||||
List[List[tuple]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len(
|
||||
problem_sizes_mnkl
|
||||
):
|
||||
raise ValueError("torch_fp32_tensors_abc must have one entry per group")
|
||||
|
||||
# Initialize lists to store tensors for all groups
|
||||
new_torch_fp32_tensors_abc = (
|
||||
[] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc
|
||||
)
|
||||
torch_tensors_abc = []
|
||||
cute_tensors_abc = []
|
||||
strides_abc = []
|
||||
ptrs_abc = []
|
||||
|
||||
# Iterate through all groups and create tensors for each group
|
||||
for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
|
||||
# Get existing CPU tensors if available, otherwise None
|
||||
existing_cpu_a = (
|
||||
torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None
|
||||
)
|
||||
existing_cpu_b = (
|
||||
torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None
|
||||
)
|
||||
existing_cpu_c = (
|
||||
torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None
|
||||
)
|
||||
|
||||
# Create tensors (reusing CPU tensors if provided)
|
||||
(
|
||||
ptr_a,
|
||||
torch_tensor_a,
|
||||
cute_tensor_a,
|
||||
tensor_fp32_a,
|
||||
stride_mk_a,
|
||||
) = create_tensor_and_stride(
|
||||
l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a
|
||||
)
|
||||
(
|
||||
ptr_b,
|
||||
torch_tensor_b,
|
||||
cute_tensor_b,
|
||||
tensor_fp32_b,
|
||||
stride_nk_b,
|
||||
) = create_tensor_and_stride(
|
||||
l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b
|
||||
)
|
||||
(
|
||||
ptr_c,
|
||||
torch_tensor_c,
|
||||
cute_tensor_c,
|
||||
tensor_fp32_c,
|
||||
stride_mn_c,
|
||||
) = create_tensor_and_stride(
|
||||
l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c
|
||||
)
|
||||
|
||||
# Only append to new_torch_fp32_tensors_abc if we created new CPU tensors
|
||||
if torch_fp32_tensors_abc is None:
|
||||
new_torch_fp32_tensors_abc.append(
|
||||
[tensor_fp32_a, tensor_fp32_b, tensor_fp32_c]
|
||||
)
|
||||
|
||||
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
|
||||
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
|
||||
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
|
||||
cute_tensors_abc.append(
|
||||
(
|
||||
cute_tensor_a,
|
||||
cute_tensor_b,
|
||||
cute_tensor_c,
|
||||
)
|
||||
)
|
||||
|
||||
return (
|
||||
ptrs_abc,
|
||||
torch_tensors_abc,
|
||||
cute_tensors_abc,
|
||||
strides_abc,
|
||||
new_torch_fp32_tensors_abc,
|
||||
)
|
||||
|
||||
|
||||
def run(
|
||||
num_groups: int,
|
||||
problem_sizes_mnkl: tuple[int, int, int, int],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
@ -1832,8 +1958,16 @@ def run_grouped_gemm(
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Run grouped GEMM example with specified configurations."""
|
||||
"""Run grouped GEMM example with specified configurations.
|
||||
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:return: Execution time of the GEMM kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
print(f"Running Blackwell Grouped GEMM test with:")
|
||||
print(f"{num_groups} groups")
|
||||
for i, (m, n, k, l) in enumerate(problem_sizes_mnkl):
|
||||
@ -1847,6 +1981,7 @@ def run_grouped_gemm(
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Skip unsupported types
|
||||
if ab_dtype not in {
|
||||
@ -1902,66 +2037,22 @@ def run_grouped_gemm(
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("GPU is required to run this example!")
|
||||
|
||||
# Create tensor and return the pointer, tensor, and stride
|
||||
def create_tensor_and_stride(
|
||||
l: int,
|
||||
mode0: int,
|
||||
mode1: int,
|
||||
is_mode0_major: bool,
|
||||
dtype: type[cutlass.Numeric],
|
||||
is_dynamic_layout: bool = True,
|
||||
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
|
||||
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
|
||||
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
|
||||
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
|
||||
)
|
||||
return (
|
||||
torch_tensor.data_ptr(),
|
||||
torch_tensor,
|
||||
cute_tensor,
|
||||
torch_tensor_cpu,
|
||||
torch_tensor.stride()[:-1],
|
||||
)
|
||||
# Create tensors for all groups using the new function
|
||||
(
|
||||
ptrs_abc,
|
||||
torch_tensors_abc,
|
||||
cute_tensors_abc,
|
||||
strides_abc,
|
||||
torch_fp32_tensors_abc,
|
||||
) = create_tensors_for_all_groups(
|
||||
problem_sizes_mnkl,
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
)
|
||||
|
||||
# iterate all groups and create tensors for each group
|
||||
torch_fp32_tensors_abc = []
|
||||
torch_tensors_abc = []
|
||||
cute_tensors_abc = []
|
||||
strides_abc = []
|
||||
ptrs_abc = []
|
||||
for _, (m, n, k, l) in enumerate(problem_sizes_mnkl):
|
||||
(
|
||||
ptr_a,
|
||||
torch_tensor_a,
|
||||
cute_tensor_a,
|
||||
tensor_fp32_a,
|
||||
stride_mk_a,
|
||||
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
|
||||
(
|
||||
ptr_b,
|
||||
torch_tensor_b,
|
||||
cute_tensor_b,
|
||||
tensor_fp32_b,
|
||||
stride_nk_b,
|
||||
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
|
||||
(
|
||||
ptr_c,
|
||||
torch_tensor_c,
|
||||
cute_tensor_c,
|
||||
tensor_fp32_c,
|
||||
stride_mn_c,
|
||||
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
|
||||
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
|
||||
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
|
||||
torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c])
|
||||
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
|
||||
cute_tensors_abc.append(
|
||||
(
|
||||
cute_tensor_a,
|
||||
cute_tensor_b,
|
||||
cute_tensor_c,
|
||||
)
|
||||
)
|
||||
# Choose A, B, C with the smallest size to create initial tensormaps
|
||||
key_size_a = lambda item: item[1][0] * item[1][2]
|
||||
key_size_b = lambda item: item[1][1] * item[1][2]
|
||||
@ -2078,36 +2169,19 @@ def run_grouped_gemm(
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Launch GPU kernel
|
||||
# Warm up
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_grouped_gemm(
|
||||
initial_cute_tensors_abc[0],
|
||||
initial_cute_tensors_abc[1],
|
||||
initial_cute_tensors_abc[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc,
|
||||
tensor_of_ptrs_abc,
|
||||
tensor_of_tensormap,
|
||||
current_stream,
|
||||
)
|
||||
# Execution
|
||||
for i in range(iterations):
|
||||
compiled_grouped_gemm(
|
||||
initial_cute_tensors_abc[0],
|
||||
initial_cute_tensors_abc[1],
|
||||
initial_cute_tensors_abc[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc,
|
||||
tensor_of_ptrs_abc,
|
||||
tensor_of_tensormap,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Compute reference result
|
||||
if not skip_ref_check:
|
||||
compiled_grouped_gemm(
|
||||
initial_cute_tensors_abc[0],
|
||||
initial_cute_tensors_abc[1],
|
||||
initial_cute_tensors_abc[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc,
|
||||
tensor_of_ptrs_abc,
|
||||
tensor_of_tensormap,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Compute reference result
|
||||
for i, (a, b, c) in enumerate(torch_tensors_abc):
|
||||
ref = torch.einsum(
|
||||
"mkl,nkl->mnl",
|
||||
@ -2122,6 +2196,102 @@ def run_grouped_gemm(
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
# Reuse existing CPU tensors and create new GPU tensors from them
|
||||
(
|
||||
ptrs_abc_workspace,
|
||||
torch_tensors_abc_workspace,
|
||||
cute_tensors_abc_workspace,
|
||||
strides_abc_workspace,
|
||||
_,
|
||||
) = create_tensors_for_all_groups(
|
||||
problem_sizes_mnkl,
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
torch_fp32_tensors_abc,
|
||||
)
|
||||
|
||||
initial_cute_tensors_abc_workspace = [
|
||||
cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k)
|
||||
cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k)
|
||||
cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n)
|
||||
]
|
||||
|
||||
# Create new tensors for this workspace
|
||||
tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like(
|
||||
torch.tensor(strides_abc_workspace, dtype=torch.int32),
|
||||
cutlass.Int32,
|
||||
is_dynamic_layout=False,
|
||||
assumed_align=16,
|
||||
)
|
||||
|
||||
tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like(
|
||||
torch.tensor(ptrs_abc_workspace, dtype=torch.int64),
|
||||
cutlass.Int64,
|
||||
is_dynamic_layout=False,
|
||||
assumed_align=16,
|
||||
)
|
||||
|
||||
tensormap_workspace, _ = cutlass_torch.cute_tensor_like(
|
||||
torch.empty(tensormap_shape, dtype=torch.int64),
|
||||
cutlass.Int64,
|
||||
is_dynamic_layout=False,
|
||||
)
|
||||
|
||||
return testing.JitArguments(
|
||||
initial_cute_tensors_abc_workspace[0],
|
||||
initial_cute_tensors_abc_workspace[1],
|
||||
initial_cute_tensors_abc_workspace[2],
|
||||
tensor_of_dim_size_mnkl,
|
||||
tensor_of_strides_abc_workspace,
|
||||
tensor_of_ptrs_abc_workspace,
|
||||
tensormap_workspace,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
sum(
|
||||
[
|
||||
sum(
|
||||
[
|
||||
torch_tensor.numel() * torch_tensor.element_size()
|
||||
for torch_tensor in group_tensors
|
||||
]
|
||||
)
|
||||
for group_tensors in torch_tensors_abc
|
||||
]
|
||||
)
|
||||
+
|
||||
# Add size of strides tensor
|
||||
tensor_of_strides_abc_torch.numel()
|
||||
* tensor_of_strides_abc_torch.element_size()
|
||||
+
|
||||
# Add size of ptrs tensor
|
||||
tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size()
|
||||
+
|
||||
# Add size of tensormap tensor
|
||||
tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_grouped_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -2218,6 +2388,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -2248,7 +2424,7 @@ if __name__ == "__main__":
|
||||
|
||||
torch.manual_seed(2025)
|
||||
|
||||
run_grouped_gemm(
|
||||
run(
|
||||
args.num_groups,
|
||||
args.problem_sizes_mnkl,
|
||||
args.ab_dtype,
|
||||
@ -2265,5 +2441,6 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -29,13 +29,14 @@
|
||||
|
||||
import argparse
|
||||
from typing import List, Type, Tuple, Optional
|
||||
from cuda import cuda
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
@ -43,13 +44,16 @@ import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
from .mamba2_ssd_reference import (
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent))
|
||||
from mamba2_ssd_reference import (
|
||||
ssd_reference_fp32_all,
|
||||
ssd_reference_lowprecision_intermediates,
|
||||
analyze_relative_diffs,
|
||||
)
|
||||
|
||||
from .mamba2_ssd_tile_scheduler import (
|
||||
from mamba2_ssd_tile_scheduler import (
|
||||
Mamba2SSDTileSchedulerParams,
|
||||
Mamba2SSDTileScheduler,
|
||||
)
|
||||
@ -122,7 +126,7 @@ class SSDKernel:
|
||||
*self.epilog_warp_id,
|
||||
)
|
||||
)
|
||||
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
# Named barriers
|
||||
self.pre_inter_sync_bar_id = 1
|
||||
@ -1522,7 +1526,10 @@ class SSDKernel:
|
||||
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
|
||||
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
|
||||
tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b(
|
||||
local_tidx, smem_bt_internal_, tiled_s2r_b, tBrB_s2r
|
||||
local_tidx,
|
||||
smem_bt_internal_,
|
||||
tiled_s2r_b,
|
||||
tBrB_s2r,
|
||||
)
|
||||
|
||||
# (MMA, MMA_M, MMA_K, INPUT_STAGE)
|
||||
@ -3053,7 +3060,7 @@ class SSDKernel:
|
||||
|
||||
# SegSum
|
||||
# fadd2 + fsel + fmul2/mufu + fmul2
|
||||
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
|
||||
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
|
||||
(
|
||||
tCompute[subtile_idx],
|
||||
tCompute[subtile_idx + 1],
|
||||
@ -3061,11 +3068,11 @@ class SSDKernel:
|
||||
(tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]),
|
||||
(-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]),
|
||||
)
|
||||
for subtile_idx in range(cute.size(tTR_rQ)):
|
||||
for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True):
|
||||
m, n = tCoord[subtile_idx]
|
||||
if m < n:
|
||||
tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
|
||||
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
|
||||
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
|
||||
# TODO: use math.exp directly
|
||||
(
|
||||
tCompute[subtile_idx],
|
||||
@ -3130,11 +3137,7 @@ class SSDKernel:
|
||||
dtype,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_r2s_b = cute.make_tiled_copy(
|
||||
copy_atom_r2s_b,
|
||||
layout_tv=tiled_s2r_b.layout_tv_tiled,
|
||||
tiler_mn=tiled_s2r_b.tiler_mn,
|
||||
)
|
||||
tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b)
|
||||
thr_r2s_b = tiled_r2s_b.get_slice(local_tidx)
|
||||
|
||||
# Partition shared tensor for smem store Bt
|
||||
@ -3333,17 +3336,24 @@ class SSDKernel:
|
||||
)
|
||||
|
||||
|
||||
def run_ssd(
|
||||
def run(
|
||||
gbehcdln: Tuple[int, int, int, int, int, int, int, int],
|
||||
io_dtype: Type[cutlass.Numeric],
|
||||
cumsum_delta_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
has_d: bool,
|
||||
d_has_hdim: bool,
|
||||
fuse_scale_d: str,
|
||||
tolerance: float,
|
||||
print_rtol_stats: bool,
|
||||
ref_lower_precision: bool,
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
has_d = fuse_scale_d != "none"
|
||||
d_has_hdim = fuse_scale_d == "vector"
|
||||
|
||||
print(f"Running B100 Mamba2 SSD with:")
|
||||
print(f"GBEHCDLN: {gbehcdln}")
|
||||
print(
|
||||
@ -3353,6 +3363,10 @@ def run_ssd(
|
||||
f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}"
|
||||
)
|
||||
print(f"Tolerance: {tolerance}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
|
||||
# Unpack parameters
|
||||
G, B, E, H, C, D, L, N = gbehcdln
|
||||
@ -3515,39 +3529,146 @@ def run_ssd(
|
||||
stream,
|
||||
)
|
||||
|
||||
# Launch compiled ssd kernel
|
||||
compiled_ssd(
|
||||
x_tensor,
|
||||
cumsum_delta_tensor,
|
||||
delta_tensor,
|
||||
b_tensor,
|
||||
c_tensor,
|
||||
y_tensor,
|
||||
fstate_tensor,
|
||||
d_tensor,
|
||||
stream,
|
||||
# Launch compiled ssd kernel for reference check
|
||||
if not skip_ref_check:
|
||||
compiled_ssd(
|
||||
x_tensor,
|
||||
cumsum_delta_tensor,
|
||||
delta_tensor,
|
||||
b_tensor,
|
||||
c_tensor,
|
||||
y_tensor,
|
||||
fstate_tensor,
|
||||
d_tensor,
|
||||
stream,
|
||||
)
|
||||
|
||||
# Reference check
|
||||
if print_rtol_stats:
|
||||
print("\nY's Relative diffs:")
|
||||
analyze_relative_diffs(
|
||||
y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))
|
||||
)
|
||||
print("\nFstate's Relative diffs:")
|
||||
analyze_relative_diffs(
|
||||
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
y_torch.cpu(),
|
||||
y_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-02,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
fstate_torch.cpu(),
|
||||
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-05,
|
||||
)
|
||||
|
||||
def generate_tensors():
|
||||
# Reuse existing CPU reference tensors and create new GPU tensors from them
|
||||
_, x_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, D, C, L],
|
||||
[2, 4, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=x_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, cumsum_delta_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, C, L],
|
||||
[3, 2, 1, 0],
|
||||
cumsum_delta_dtype,
|
||||
ref_tensor=cumsum_delta_ref,
|
||||
dynamic_modes=[1, 2, 3],
|
||||
)
|
||||
_, delta_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, C, L],
|
||||
[3, 2, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=delta_ref,
|
||||
dynamic_modes=[1, 2, 3],
|
||||
)
|
||||
_, b_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, G, N, C, L],
|
||||
[4, 2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=b_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, c_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, G, N, C, L],
|
||||
[4, 2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=c_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, y_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, D, C, L],
|
||||
[4, 2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=y_ref,
|
||||
dynamic_modes=[2, 3, 4],
|
||||
)
|
||||
_, fstate_tensor_new, _ = create_and_permute_tensor(
|
||||
[B, EH, D, N],
|
||||
[2, 3, 1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=fstate_ref,
|
||||
dynamic_modes=[2, 3],
|
||||
)
|
||||
|
||||
if has_d:
|
||||
_, d_tensor_new, _ = create_and_permute_tensor(
|
||||
[EH, D if d_has_hdim else 1],
|
||||
[1, 0],
|
||||
io_dtype,
|
||||
ref_tensor=d_ref,
|
||||
dynamic_modes=[1],
|
||||
)
|
||||
else:
|
||||
d_tensor_new = d_tensor
|
||||
|
||||
return testing.JitArguments(
|
||||
x_tensor_new,
|
||||
cumsum_delta_tensor_new,
|
||||
delta_tensor_new,
|
||||
b_tensor_new,
|
||||
c_tensor_new,
|
||||
y_tensor_new,
|
||||
fstate_tensor_new,
|
||||
d_tensor_new,
|
||||
stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
x_torch.numel() * x_torch.element_size()
|
||||
+ cumsum_delta_torch.numel() * cumsum_delta_torch.element_size()
|
||||
+ delta_torch.numel() * delta_torch.element_size()
|
||||
+ b_torch.numel() * b_torch.element_size()
|
||||
+ c_torch.numel() * c_torch.element_size()
|
||||
+ y_torch.numel() * y_torch.element_size()
|
||||
+ fstate_torch.numel() * fstate_torch.element_size()
|
||||
)
|
||||
if has_d:
|
||||
one_workspace_bytes += d_torch.numel() * d_torch.element_size()
|
||||
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_ssd,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Reference check
|
||||
if print_rtol_stats:
|
||||
print("\nY's Relative diffs:")
|
||||
analyze_relative_diffs(y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)))
|
||||
print("\nFstate's Relative diffs:")
|
||||
analyze_relative_diffs(
|
||||
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
y_torch.cpu(),
|
||||
y_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-02,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
fstate_torch.cpu(),
|
||||
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
|
||||
atol=tolerance,
|
||||
rtol=1e-05,
|
||||
)
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -3586,15 +3707,53 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_lower_precision",
|
||||
type=bool,
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use lower precision for reference check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-ref_lower_precision",
|
||||
action="store_false",
|
||||
dest="ref_lower_precision",
|
||||
default=False,
|
||||
help="Disable lower precision for reference check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tolerance", type=float, default=5e-02, help="Tolerance for validation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_rtol_stats", type=bool, default=True, help="Print rtol stats"
|
||||
"--print_rtol_stats",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Enable print rtol stats",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-print_rtol_stats",
|
||||
action="store_false",
|
||||
dest="print_rtol_stats",
|
||||
default=False,
|
||||
help="Disable print rtol stats",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_iterations",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of warmup iterations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iterations",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of iterations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -3602,18 +3761,18 @@ if __name__ == "__main__":
|
||||
if len(args.gbehcdln) != 8:
|
||||
parser.error("--gbehcdln must contain exactly 8 values")
|
||||
|
||||
has_d = args.fuse_scale_d != "none"
|
||||
d_has_hdim = args.fuse_scale_d == "vector"
|
||||
|
||||
run_ssd(
|
||||
run(
|
||||
args.gbehcdln,
|
||||
args.io_dtype,
|
||||
args.cumsum_delta_dtype,
|
||||
args.acc_dtype,
|
||||
has_d,
|
||||
d_has_hdim,
|
||||
args.fuse_scale_d,
|
||||
args.tolerance,
|
||||
args.print_rtol_stats,
|
||||
args.ref_lower_precision,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -35,6 +35,7 @@ import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.torch as cutlass_torch
|
||||
@ -166,6 +167,24 @@ def parse_arguments() -> argparse.Namespace:
|
||||
parser.add_argument(
|
||||
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iterations",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of iterations to run the kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -264,7 +283,7 @@ class HopperWgmmaGemmKernel:
|
||||
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
|
||||
self.num_threads_per_warp_group = 128
|
||||
self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
|
||||
self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"]
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
|
||||
|
||||
self.ab_stage = None
|
||||
self.epi_stage = None
|
||||
@ -1309,7 +1328,7 @@ class HopperWgmmaGemmKernel:
|
||||
}:
|
||||
is_valid = False
|
||||
# tested acc_dtype
|
||||
if acc_dtype != cutlass.Float32:
|
||||
if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
|
||||
is_valid = False
|
||||
# tested c_dtype
|
||||
if c_dtype not in {
|
||||
@ -1335,7 +1354,7 @@ class HopperWgmmaGemmKernel:
|
||||
return is_valid
|
||||
|
||||
|
||||
def run_dense_gemm(
|
||||
def run(
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
a_dtype: Type[cutlass.Numeric],
|
||||
b_dtype: Type[cutlass.Numeric],
|
||||
@ -1347,9 +1366,43 @@ def run_dense_gemm(
|
||||
tile_shape_mnk: Tuple[int, int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
tolerance: float,
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
||||
|
||||
:param mnkl: Problem size (M, N, K, L)
|
||||
:type mnkl: Tuple[int, int, int, int]
|
||||
:param a_dtype: Data type for input tensor A
|
||||
:type a_dtype: Type[cutlass.Numeric]
|
||||
:param b_dtype: Data type for input tensor B
|
||||
:type b_dtype: Type[cutlass.Numeric]
|
||||
:param c_dtype: Data type for output tensor C
|
||||
:type c_dtype: Type[cutlass.Numeric]
|
||||
:param acc_dtype: Data type for accumulation during matrix multiplication
|
||||
:type acc_dtype: Type[cutlass.Numeric]
|
||||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||||
:type a_major/b_major/c_major: str
|
||||
:param tile_shape_mnk: CTA tile shape (M, N, K)
|
||||
:type tile_shape_mnk: Tuple[int, int, int]
|
||||
:param cluster_shape_mn: Cluster shape (M, N)
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
:param tolerance: Tolerance value for reference validation comparison
|
||||
:type tolerance: float
|
||||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
||||
:type warmup_iterations: int, optional
|
||||
:param iterations: Number of benchmark iterations to run, defaults to 1
|
||||
:type iterations: int, optional
|
||||
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||||
:type use_cold_l2: bool, optional
|
||||
:return: Execution time of the GEMM kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
print(f"Running Hopper Dense GEMM with:")
|
||||
@ -1360,6 +1413,10 @@ def run_dense_gemm(
|
||||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||||
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
|
||||
print(f"Tolerance: {tolerance}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {use_cold_l2}")
|
||||
|
||||
# Unpack parameters
|
||||
m, n, k, l = mnkl
|
||||
@ -1437,46 +1494,76 @@ def run_dense_gemm(
|
||||
stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
# compile gemm kernel
|
||||
compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
|
||||
# execution
|
||||
compiled_gemm(mA, mB, mC, stream)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
if not skip_ref_check:
|
||||
# execution
|
||||
compiled_gemm(mA, mB, mC, stream)
|
||||
|
||||
# Ref check
|
||||
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
|
||||
# m major: (l, n, m) -> (m, n, l)
|
||||
# k major: (l, m, n) -> (m, n, l)
|
||||
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
|
||||
shape = (l, m, n) if c_major == "n" else (l, n, m)
|
||||
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
|
||||
shape,
|
||||
torch.uint8,
|
||||
permute_order=permute_order,
|
||||
init_type=cutlass_torch.TensorInitType.SKIP,
|
||||
).cuda()
|
||||
# Create dtype cute tensor (gpu)
|
||||
ref_c_tensor = from_dlpack(
|
||||
f8_torch_tensor, assumed_align=16
|
||||
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
ref_c_tensor.element_type = c_dtype
|
||||
ref_c_tensor = cutlass_torch.convert_cute_tensor(
|
||||
ref,
|
||||
ref_c_tensor,
|
||||
c_dtype,
|
||||
is_dynamic_layout=True,
|
||||
# Ref check
|
||||
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
|
||||
|
||||
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
|
||||
# m major: (l, n, m) -> (m, n, l)
|
||||
# n major: (l, m, n) -> (m, n, l)
|
||||
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
|
||||
shape = (l, m, n) if c_major == "n" else (l, n, m)
|
||||
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
|
||||
shape,
|
||||
torch.uint8,
|
||||
permute_order=permute_order,
|
||||
init_type=cutlass_torch.TensorInitType.SKIP,
|
||||
).cuda()
|
||||
# Create dtype cute tensor (gpu)
|
||||
ref_c_tensor = from_dlpack(
|
||||
f8_torch_tensor, assumed_align=16
|
||||
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
ref_c_tensor.element_type = c_dtype
|
||||
ref_c_tensor = cutlass_torch.convert_cute_tensor(
|
||||
ref,
|
||||
ref_c_tensor,
|
||||
c_dtype,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
ref_c = f8_torch_tensor.cpu()
|
||||
else:
|
||||
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
|
||||
|
||||
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
|
||||
|
||||
def generate_tensors():
|
||||
_, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
|
||||
_, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
|
||||
_, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
|
||||
return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
one_workspace_bytes = (
|
||||
a_torch.numel() * a_torch.element_size()
|
||||
+ b_torch.numel() * b_torch.element_size()
|
||||
+ c_torch.numel() * c_torch.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
ref_c = f8_torch_tensor.cpu()
|
||||
else:
|
||||
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
|
||||
|
||||
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
|
||||
exec_time = testing.benchmark(
|
||||
compiled_gemm,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
run_dense_gemm(
|
||||
run(
|
||||
args.mnkl,
|
||||
args.a_dtype,
|
||||
args.b_dtype,
|
||||
@ -1488,5 +1575,9 @@ if __name__ == "__main__":
|
||||
args.tile_shape_mnk,
|
||||
args.cluster_shape_mn,
|
||||
args.tolerance,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@ -399,6 +399,70 @@
|
||||
"\n",
|
||||
"tensor_print_example3()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def print_tensor_gpu(src: cute.Tensor):\n",
|
||||
" print(src)\n",
|
||||
" cute.print_tensor(src)\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_host(src: cute.Tensor):\n",
|
||||
" print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
|
||||
" [[-0.690547, -0.274619, -1.659539, ],\n",
|
||||
" [-1.843524, -1.648711, 1.163431, ],\n",
|
||||
" [-0.716668, -1.900705, 0.592515, ],\n",
|
||||
" [ 0.711333, -0.552422, 0.860237, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"def tensor_print_example4():\n",
|
||||
" a = torch.randn(4, 3, device=\"cuda\")\n",
|
||||
" cutlass.cuda.initialize_cuda_context()\n",
|
||||
" print_tensor_host(from_dlpack(a))\n",
|
||||
"\n",
|
||||
"tensor_print_example4()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@ -256,16 +256,6 @@
|
||||
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
|
||||
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
|
||||
"\n",
|
||||
"@cute.kernel\n",
|
||||
"def print_tensor_gpu(ptr: cute.Pointer):\n",
|
||||
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
|
||||
" tensor = cute.make_tensor(ptr, layout)\n",
|
||||
"\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
"\n",
|
||||
" if tidx == 0:\n",
|
||||
" cute.print_tensor(tensor)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create a tensor with sequential data using torch\n",
|
||||
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",
|
||||
|
||||
@ -363,7 +363,7 @@
"| | \"few_channels\" | optimized for small `C` and requires `C % alignment_input == 0`|\n",
"| | \"fixed_channels\" | optimized for small `C` and requires `C == alignment_input` |\n",
"|Dgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimzed for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"| | \"optimized\" | Optimized for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"|Wgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimized for and require `K % alignment_grad_output == 0`, and `C % alignment_input == 0`|\n",
"\n",