CUTLASS 3.5.1 (#1623)

* CUTLASS 3.5.1

* updates, optimizations, fixes
Vijay Thakkar
2024-07-29 08:46:24 -04:00
committed by GitHub
parent 56b46e2d13
commit be60a0b272
312 changed files with 19793 additions and 6775 deletions

View File

@ -162,7 +162,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -161,7 +161,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
//

View File

@ -231,7 +231,7 @@ struct B2bFusedGroupedGemmRun
host_tensor_ref_D1.at(i).sync_device();
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());;
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
if (alpha0 == ElementCompute(0)) //per-channel scale
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
@ -340,7 +340,7 @@ struct B2bFusedGroupedGemmRun
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
for (int i = 0; i < problem_count; ++i) {
host_tensor_D1.at(i).sync_host();;
host_tensor_D1.at(i).sync_host();
//
// Verify

View File

@ -194,7 +194,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -33,6 +33,11 @@ cutlass_example_add_executable(
ampere_sparse_tensorop_gemm.cu
)
cutlass_example_add_executable(
15_ampere_sparse_tensorop_gemm_universal
ampere_sparse_tensorop_gemm_universal.cu
)
cutlass_example_add_executable(
15_ampere_sparse_tensorop_gemm_with_visitor
ampere_sparse_tensorop_gemm_with_visitor.cu

View File

@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -0,0 +1,329 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
Please check examples 07, 08 and 17 for the basics of dense tensor op GEMM kernels. The NVIDIA Ampere
architecture also supports structured sparse tensor ops for tf32, fp16, int8 and int4.
Sparse GEMM kernels need an additional E matrix which stores the metadata. The format of the
metadata is different for each data type. CUTLASS templates can automatically infer it based on
input A and B. Check the code below.
Moreover, matrix E needs to be preprocessed so that it can be loaded into registers
efficiently using ldmatrix.
*/
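// Concretely, for the int4 kernel instantiated below (kSparse == 2, i.e. 50% sparsity),
// matrix A is stored compressed with shape M x (K / 2), and the metadata matrix E has shape
// M x (K / 2 / kElementsPerElementE); this matches the HostTensor allocations in run().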
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_sparse_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/host_uncompress.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = int32_t; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B
using ElementOutput = int32_t; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices. Row Major for
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// memory access. For a byte, it's 16
// elements. This becomes the vector width of
// math instructions in the epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
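// For the int32_t output used here this evaluates to 128 / 32 = 4 elements per access.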
// Number of pipelines you want to use
constexpr int NumStages = 3;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;
// Data type and layout of meta data matrix E can be inferred from template Gemm.
using ElementInputE = typename Gemm::ElementE;
using LayoutInputE = cutlass::layout::RowMajor;
using ReorderedLayoutInputE = typename Gemm::LayoutE;
// The properties below are defined in include/cutlass/arch/sp_mma_sm80.h
// 50% Sparsity on Ampere
constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
int run() {
const int length_m = 512;
const int length_n = 512;
const int length_k = 1024;
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// CUTLASS kernel
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// reference kernel
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Same size as the above. The above one needs to be reordered and stored in this one.
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(2),
ElementInputA(-2),
0); // <- Fill matrix A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(2),
ElementInputB(-2),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(2),
ElementOutput(-2),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomSparseMeta(
tensor_e.host_view(),
1,
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
// instructions.
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
{problem_size.m(), problem_size.n(),
problem_size.k() / kSparse / kElementsPerElementE});
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_e_reordered.sync_device();
tensor_ref_d.sync_device();
// Initialize alpha and beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
// Split K dimension into 2 partitions
int split_k_slices = 2;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated CUTLASS kernel
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
problem_size, // <- problem size of matrix multiplication
split_k_slices,// <- k-dimension split factor
{alpha, beta}, // <- tuple of alpha and beta
tensor_a.device_data(), // <- reference to matrix A on device
tensor_b.device_data(), // <- reference to matrix B on device
tensor_c.device_data(), // <- reference to matrix C on device
tensor_d.device_data(), // <- reference to matrix D on device
tensor_e_reordered.device_data(), // <- reference to matrix E on device
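// The five int64_t() values below correspond to the batch strides for A, B, C, D and E;
// they are value-initialized to zero since this example runs a single, non-batched GEMM.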
int64_t(),
int64_t(),
int64_t(),
int64_t(),
int64_t(),
tensor_a.layout().stride(0),
tensor_b.layout().stride(0),
tensor_c.layout().stride(0),
tensor_d.layout().stride(0),
tensor_e_reordered.layout().stride(0)
};
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
// Launch initialized CUTLASS kernel
status = gemm_op();
CUTLASS_CHECK(status);
// Uncompress tensor_a based on the metadata in tensor_e. We need it for the reference computation.
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
tensor_e.host_ref(), problem_size.m(), problem_size.k());
// Create instantiation for host reference gemm kernel
cutlass::reference::host::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementComputeEpilogue,
typename Gemm::Operator>
gemm_host;
// Launch host reference gemm kernel
gemm_host(problem_size,
alpha,
tensor_a_uncompressed.host_ref(),
tensor_b.host_ref(),
beta,
tensor_c.host_ref(),
tensor_ref_d.host_ref());
// Copy output data from CUTLASS host for comparison
tensor_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
std::cout << (passed ? "Passed" : "Failed") << std::endl;
return (passed ? 0 : -1);
}
int main() {
bool notSupported = false;
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 11.1.
//
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major * 10 + props.minor < 80) {
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;
}
if (notSupported) {
// Returning zero so this test passes on older toolkits. Its actions are a no-op.
return 0;
}
return run();
}

View File

@ -94,7 +94,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 1
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
using Operator = cutlass::arch::OpMultiplyAdd;
using Operator = cutlass::arch::OpMultiplyAddSaturate;
// Number of pipelines you want to use
constexpr int NumStages = 3;

View File

@ -138,7 +138,7 @@ using Gemm = typename cutlass::gemm::device::GemmWithKReduction<
>;
// Below is the reduction kernel used in the case of parallel split-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;
using ReduceOp = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator,
@ -154,7 +154,7 @@ using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<
using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK<ReduceGemmSplitKKernel>;
using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;;
using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;
// This code section describes the epilogue part of the kernel, we use default value
using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -258,7 +258,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -221,7 +221,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -193,7 +193,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

View File

@ -215,7 +215,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8
// 16, 8, 16 -> Ampere
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Define the epilogue operation as LinearCombination. This is approximately equal to
//

View File

@ -118,7 +118,7 @@ operation = Conv2dOperation(
conv_kind=cutlass_bindings.conv.Operator.fprop,
iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized,
arch=cc, tile_description=tile_description,
A=A, B=B, C=C, stride_support=StrideSupport.Strided,
A=A, B=B, C=C, stride_support=StrideSupport.Unity,
epilogue_functor=epilogue_functor
)

View File

@ -1799,7 +1799,7 @@ struct B2bGemm<
if (rowIdx == 1) {
lse_prefetched[colIdx] = accum_n < lse_extent
? lse[accum_n]
: platform::numeric_limits<accum_t>::infinity();
: cutlass::platform::numeric_limits<accum_t>::infinity();
}
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
++colIdx;
@ -1938,7 +1938,7 @@ struct B2bGemm<
if (rowIdx == 1) {
lse_prefetched[colIdx] = accum_n < lse_extent
? lse[accum_n]
: platform::numeric_limits<accum_t>::infinity();
: cutlass::platform::numeric_limits<accum_t>::infinity();
}
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
++colIdx;

View File

@ -35,19 +35,23 @@
This example demonstrates a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0
APIs on the NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
which are more efficient than the Ampere tensor core instructions.
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and
convert them implicitly to TF32.
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
Examples:
$ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048
$ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2
*/
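// The rasterization direction and swizzle size described above are forwarded to the
// persistent tile scheduler when the kernel arguments are built in args_from_options(), e.g.:
//   arguments.scheduler.raster_order     = options.raster;
//   arguments.scheduler.max_swizzle_size = options.swizzle;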
#include <iostream>
@ -63,6 +67,7 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
@ -105,7 +110,7 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
@ -175,6 +180,8 @@ cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
// Command line options parsing
struct Options {
@ -183,12 +190,16 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
RasterOrderOptions raster;
int swizzle;
Options():
help(false),
m(5120), n(4096), k(4096),
alpha(1.f), beta(0.f),
iterations(1000)
iterations(1000),
raster(RasterOrderOptions::Heuristic),
swizzle(1)
{ }
// Parses the command line
@ -206,6 +217,21 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
if (raster_char == 'N' || raster_char == 'n') {
raster = RasterOrderOptions::AlongN;
}
else if (raster_char == 'M' || raster_char == 'm') {
raster = RasterOrderOptions::AlongM;
}
else if (raster_char == 'H' || raster_char == 'h') {
raster = RasterOrderOptions::Heuristic;
}
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
}
/// Prints the usage statement.
@ -220,6 +246,8 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
@ -294,10 +322,10 @@ bool initialize_block(
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
block_A.reset(options.m * options.k);
block_B.reset(options.k * options.n);
@ -320,6 +348,10 @@ typename Gemm::Arguments args_from_options(const Options &options)
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
arguments.scheduler.raster_order = options.raster;
// The tile scheduler will swizzle up to 8, using the nearest power of 2 (i.e., 1, 2, 4, or 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
@ -408,7 +440,17 @@ int run(Options &options)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
@ -441,7 +483,6 @@ int main(int argc, char const **args) {
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//

View File

@ -538,9 +538,8 @@ int main(int argc, char const **args) {
std::cout
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
return 0;
return 0;
}
//
// Parse options
//

View File

@ -354,9 +354,8 @@ int main(int argc, char const **args) {
std::cout
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
return 0;
return 0;
}
//
// Parse options
//

View File

@ -627,7 +627,6 @@ int main(int argc, const char ** argv) {
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
notSupported = true;
}
if (notSupported) {
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
}

View File

@ -166,7 +166,7 @@ public:
to_underlying_arguments(Arguments const& args, void* workspace) {
(void) workspace;
auto problem_shape = args.problem_shape;
if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
// swap M/N
get<0>(problem_shape) = get<1>(args.problem_shape);
get<1>(problem_shape) = get<0>(args.problem_shape);
@ -181,8 +181,7 @@ public:
};
}
CUTLASS_HOST_DEVICE static
bool
static bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);

View File

@ -119,7 +119,7 @@ public:
}
template<class ProblemShape>
CUTLASS_HOST_DEVICE static bool
static bool
can_implement(
[[maybe_unused]] ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {

View File

@ -750,7 +750,6 @@ int main(int argc, char const **argv)
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
notSupported = true;
}
if (notSupported) {
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
}

View File

@ -47,9 +47,13 @@
4. This example shows all the important fusions used by FP8 GEMM kernels,
i.e., scale factors for the A, B, C and D tensors, and the abs_max value of the D tensor.
5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
Examples:
$ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048
$ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2
*/
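// As in example 48, the --raster and --swizzle options are forwarded to the persistent
// tile scheduler through arguments.scheduler.raster_order and
// arguments.scheduler.max_swizzle_size in args_from_options() below.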
#include <iostream>
@ -63,6 +67,7 @@
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
@ -214,6 +219,8 @@ cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
/// Result structure
struct Result
{
@ -273,7 +280,7 @@ bool initialize_tensor(
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
void initialize(const Options<RasterOrderOptions> &options) {
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
@ -346,7 +353,7 @@ void initialize(const Options &options) {
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
@ -392,10 +399,14 @@ typename Gemm::Arguments args_from_options(const Options &options)
fusion_args.amax_D_ptr = abs_max_D.device_data();
}
arguments.scheduler.raster_order = options.raster;
// The tile scheduler will swizzle up to 8, using the nearest power of 2 (i.e., 1, 2, 4, or 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
bool verify(const Options &options) {
bool verify(const Options<RasterOrderOptions> &options) {
//
// Compute reference output
//
@ -468,7 +479,7 @@ bool verify(const Options &options) {
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
int run(Options<RasterOrderOptions> &options)
{
initialize(options);
@ -518,7 +529,17 @@ int run(Options &options)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
@ -551,12 +572,11 @@ int main(int argc, char const **args) {
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
Options<RasterOrderOptions> options;
options.parse(argc, args);

View File

@ -30,6 +30,7 @@
**************************************************************************************************/
// Command line options parsing
template<typename RasterOrderOptions>
struct Options {
bool help = false;
@ -41,6 +42,8 @@ struct Options {
bool save_amax = true;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
RasterOrderOptions raster;
int swizzle;
// Parses the command line
void parse(int argc, char const **args) {
@ -66,6 +69,21 @@ struct Options {
cmd.get_cmd_line_argument("save_aux", save_aux, true);
cmd.get_cmd_line_argument("save_amax", save_amax, true);
cmd.get_cmd_line_argument("iterations", iterations);
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
if (raster_char == 'N' || raster_char == 'n') {
raster = RasterOrderOptions::AlongN;
}
else if (raster_char == 'M' || raster_char == 'm') {
raster = RasterOrderOptions::AlongM;
}
else if (raster_char == 'H' || raster_char == 'h') {
raster = RasterOrderOptions::Heuristic;
}
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
}
/// Prints the usage statement.
@ -89,6 +107,8 @@ struct Options {
<< " --device_scale=<bool> Copy scalars to device memory before kernel launch (default: false)\n"
<< " --save_aux=<bool> Save the pre-activation as an auxiliary tensor (default: true)\n"
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out

View File

@ -687,7 +687,6 @@ int main(int argc, char const **args) {
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//

View File

@ -99,7 +99,7 @@ using TileShape = Shape<_256,_128,_64>; // T
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
@ -492,7 +492,6 @@ int main(int argc, char const **args) {
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//

View File

@ -30,10 +30,10 @@
set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=1) # Default problem sizes
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Default problem sizes w/ Epilogue Op test
set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test
set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes

View File

@ -32,7 +32,7 @@
/*! \file
\brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
warp-specialized cooperative kernel.
For this example all scheduling work is performed on the device.
The new feature showcased in this example is on-the-fly modification of TMA descriptors
@ -42,7 +42,7 @@
$ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
Skipping any of the problem dimensions randomizes it across the different groups.
Same applies for alpha and beta values that are randomized across the different groups.
@ -117,7 +117,7 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
@ -163,10 +163,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
// Host-side allocations
std::vector<int64_t> offset_A;
@ -226,7 +226,7 @@ struct Options {
std::string benchmark_path;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
int const tma_alignment_bits = 128;
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
// Parses the command line
void parse(int argc, char const **args) {
@ -438,10 +438,10 @@ void allocate(const Options &options) {
total_elements_C += elements_C;
total_elements_D += elements_D;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
}
@ -456,7 +456,7 @@ void allocate(const Options &options) {
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
@ -695,7 +695,6 @@ int main(int argc, char const **args) {
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//

View File

@ -97,7 +97,7 @@ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWith
cutlass::epilogue::thread::ReLu,
ElementOutput,
ElementAuxOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
8,
ElementAccumulator,
ElementAccumulator
>;
@ -106,7 +106,7 @@ template <typename MathOperator>
using Gemm_ = cutlass::gemm::device::GemmUniversalWithAbsMax<
ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89,
cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages,
kAlignmentA, kAlignmentB, MathOperator
>;

View File

@ -53,3 +53,8 @@ cutlass_example_add_executable(
tiled_copy.cu
)
cutlass_example_add_executable(
wgmma_sm90
wgmma_sm90.cu
)

View File

@ -153,12 +153,12 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K)
CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K)
CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K
// Clear the accumulators
clear(tCrC);
@ -358,7 +358,7 @@ gemm_nt(int m, int n, int k,
alpha, beta);
}
// Setup params for a NT GEMM
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
@ -391,10 +391,10 @@ gemm_tn(int m, int n, int k,
auto bP = Int<3>{}; // Pipeline
// Define the smem layouts (static)
auto sA_atom = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB_atom = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sA_atom = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
[[maybe_unused]] auto sB_atom = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
auto sB = tile_to_shape(sA_atom, make_shape(bN, bK, bP));
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx

View File

@ -0,0 +1,562 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/device_kernel.h"
using namespace cute;
template <class ElementA,
class ElementB,
class SmemLayoutA, // (M,K,P)
class SmemLayoutB> // (N,K,P)
struct SharedStorage
{
array_aligned<ElementA, cosize_v<SmemLayoutA>> smem_A;
array_aligned<ElementB, cosize_v<SmemLayoutB>> smem_B;
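// One producer (TMA) barrier and one consumer (MMA) barrier per pipeline stage;
// mode-2 of SmemLayoutA is the pipe count P.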
uint64_t tma_barrier[size<2>(SmemLayoutA{})];
uint64_t mma_barrier[size<2>(SmemLayoutA{})];
};
template <class ProblemShape, class CtaTiler,
class TA, class SmemLayoutA, class TmaA,
class TB, class SmemLayoutB, class TmaB,
class TC, class CStride, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, CUTLASS_GRID_CONSTANT TmaA const tma_a,
TB const* B, CUTLASS_GRID_CONSTANT TmaB const tma_b,
TC * C, CStride dC, TiledMma mma,
Alpha alpha, Beta beta)
{
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
static_assert(is_static<SmemLayoutA>::value);
static_assert(is_static<SmemLayoutB>::value);
CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutA{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutB{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutA{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutB{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
auto [M, N, K] = shape_MNK;
Tensor mA = tma_a.get_tma_tensor(make_shape(M,K)); // (M,K) TMA Tensor
Tensor mB = tma_b.get_tma_tensor(make_shape(N,K)); // (N,K) TMA Tensor
Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory tensors
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, SmemLayoutA, SmemLayoutB>;
SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Partition the copying of A and B tiles
//
// TUTORIAL:
// These are TMA partitionings, which have a dedicated custom partitioner.
// The Int<0>, Layout<_1> indicates that the TMAs are not multicasted.
// Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host.
// The group_modes<0,2> transforms the (X,Y,Z)-shaped tensors into ((X,Y),Z)-shaped tensors
// with the understanding that the TMA is responsible for everything in mode-0.
// The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
//
auto [tAgA, tAsA] = tma_partition(tma_a, Int<0>{}, Layout<_1>{},
group_modes<0,2>(sA), group_modes<0,2>(gA)); // (TMA,k) and (TMA,PIPE)
auto [tBgB, tBsB] = tma_partition(tma_b, Int<0>{}, Layout<_1>{},
group_modes<0,2>(sB), group_modes<0,2>(gB)); // (TMA,k) and (TMA,PIPE)
// The TMA is responsible for copying everything in mode-0 of tAsA and tBsB
constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) +
CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB);
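// This is the number of bytes the two TMA loads deposit into shared memory for one pipeline
// stage; it is the expected-transaction count passed to arrive_and_expect_tx() below, so the
// producer barrier completes only after that many bytes have arrived.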
//
// PREFETCH
//
auto K_PIPE_MAX = size<1>(tAsA);
// Total count of tiles
int k_tile_count = size<1>(tAgA);
// Current tile index in gmem to read from
int k_tile = 0;
// Initialize Barriers
int warp_idx = cutlass::canonical_warp_idx_sync();
int lane_predicate = cute::elect_one_sync();
uint64_t* producer_mbar = smem.tma_barrier;
uint64_t* consumer_mbar = smem.mma_barrier;
using ProducerBarType = cutlass::arch::ClusterTransactionBarrier; // TMA
using ConsumerBarType = cutlass::arch::ClusterBarrier; // MMA
CUTE_UNROLL
for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe) {
if ((warp_idx == 0) && lane_predicate) {
ProducerBarType::init(&producer_mbar[pipe], 1);
ConsumerBarType::init(&consumer_mbar[pipe], 128);
}
}
// Ensure barrier init is complete on all CTAs
cluster_sync();
// Start async loads for all pipes
CUTE_UNROLL
for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe)
{
if ((warp_idx == 0) && lane_predicate)
{
// Set expected Tx Bytes after each reset / init
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
}
--k_tile_count;
++k_tile;
}
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL:
// The tCrA and tCrB are actually Tensors of MMA Descriptors constructed as views of SMEM.
// The MMA Descriptor generation is automatic via inspection and validation of the SMEM Layouts.
// Because the MMA reads directly from SMEM and the fragments are descriptors rather than registers,
// there is no need for copy(tCsA, tCrA) in the mainloop.
//
ThrMMA thr_mma = mma.get_thread_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate accumulators and clear them
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
clear(tCrC);
// Allocate "fragments"
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
//
// PIPELINED MAIN LOOP
//
// TUTORIAL:
// Rather than interleaving the stages and instructions like in SM70 and SM80,
// the SM90 mainloops rely on explicit producer-consumer synchronization
// on the purely async instructions TMA and MMA.
// More advanced pipeline and warp-specialization strategies are available in CUTLASS mainloops.
//
// A PipelineState is a circular pipe index [.index()] and a pipe phase [.phase()]
// that flips each cycle through K_PIPE_MAX.
auto write_state = cutlass::PipelineState<K_PIPE_MAX>(); // TMA writes
auto read_state = cutlass::PipelineState<K_PIPE_MAX>(); // MMA reads
CUTE_NO_UNROLL
while (k_tile_count > -K_PIPE_MAX)
{
// Wait for Producer to complete
int read_pipe = read_state.index();
ProducerBarType::wait(&producer_mbar[read_pipe], read_state.phase());
// MMAs to cover 1 K_TILE
warpgroup_arrive();
gemm(mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC); // (V,M) x (V,N) => (V,M,N)
warpgroup_commit_batch();
// Wait for all MMAs in a K_TILE to complete
warpgroup_wait<0>();
// Notify that consumption is done
ConsumerBarType::arrive(&consumer_mbar[read_pipe]);
++read_state;
if ((warp_idx == 0) && lane_predicate)
{
int pipe = write_state.index();
// Wait for Consumer to complete consumption
ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase());
// Set expected Tx Bytes after each reset / init
ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
++write_state;
}
--k_tile_count;
++k_tile;
}
//
// Epilogue (unpredicated)
//
axpby(alpha, tCrC, beta, tCgC);
}
// Setup params for an NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 64>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
auto bP = Int< 3>{}; // Pipeline
// Define the smem layouts (static)
auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TB>{}, make_shape(bN,bK,bP));
// Define the MMA
TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});
// Define the TMAs
// Create Global memory tensors for TMA inspection
Tensor mA = make_tensor(A, make_shape(M,K), dA);
Tensor mB = make_tensor(B, make_shape(N,K), dB);
// Create TMA Atoms with the desired copy operation on the source and destination
Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK));
Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK));
//
// Setup and Launch
//
// Launch parameter setup
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
dim3 dimBlock(size(tiled_mma));
dim3 dimCluster(2, 1, 1);
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
round_up(size(ceil_div(n, bN)), dimCluster.y));
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};
void const* kernel_ptr = reinterpret_cast<void const*>(
&gemm_device<decltype(prob_shape), decltype(cta_tiler),
TA, decltype(sA), decltype(tmaA),
TB, decltype(sB), decltype(tmaB),
TC, decltype(dC), decltype(tiled_mma),
decltype(alpha), decltype(beta)>);
CUTE_CHECK_ERROR(cudaFuncSetAttribute(
kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
// Kernel Launch
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
prob_shape, cta_tiler,
A, tmaA,
B, tmaB,
C, dC, tiled_mma,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 64>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
auto bP = Int<3>{}; // Pipeline
// Define the smem layouts (static)
auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom<TB>{}, make_shape(bN,bK,bP));
// Define the MMA
TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::K,GMMA::Major::K>{});
// Define the TMAs
// Create Global memory tensors for TMA inspection
Tensor mA = make_tensor(A, make_shape(M,K), dA);
Tensor mB = make_tensor(B, make_shape(N,K), dB);
// Create TMA Atoms with the desired copy operation on the source and destination
Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK));
Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK));
//
// Setup and Launch
//
// Launch parameter setup
int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
dim3 dimBlock(size(tiled_mma));
dim3 dimCluster(2, 1, 1);
dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
round_up(size(ceil_div(n, bN)), dimCluster.y));
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};
void const* kernel_ptr = reinterpret_cast<void const*>(
&gemm_device<decltype(prob_shape), decltype(cta_tiler),
TA, decltype(sA), decltype(tmaA),
TB, decltype(sB), decltype(tmaB),
TC, decltype(dC), decltype(tiled_mma),
decltype(alpha), decltype(beta)>);
CUTE_CHECK_ERROR(cudaFuncSetAttribute(
kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
// Kernel Launch
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
prob_shape, cta_tiler,
A, tmaA,
B, tmaB,
C, dC, tiled_mma,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major != 9) {
std::cout << "This example requires NVIDIA's Hopper Architecture GPU with compute capability 90a\n" << std::endl;
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
int m = 512;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 256;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 1024;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = cute::half_t;
using TB = cute::half_t;
using TC = cute::half_t;
using TI = cute::half_t;
TI alpha = TI(1.0f);
TI beta = TI(0.0f);
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
// Initialize the tensors
for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1));
for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1));
for (int j = 0; j < m*n; ++j) h_C[j] = TC(0);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
#else
std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}