/***************************************************************************************************
 * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
|
|
|
|
/*! \file
    \brief A FP16 sparse GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.

    The Blackwell SM100 CUTLASS kernel uses the following Blackwell SM100 features:

    1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
    which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).

    Note that Hopper WGMMA Tensor Core MMA instructions are not compatible with Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution).

    2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
    Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
    Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).

    3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM
    which allows us to decouple the execution of MMA and epilogue into separate warps.

    4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).

    Usage:
    $ ./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm --m=8192 --n=8192 --k=8192
*/
|
|
|
|
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"

#include "helper.h"
|
|
|
|
using namespace cute;
|
|
|
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
/// GEMM kernel configurations
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// A matrix configuration
|
|
using ElementA = half_t; // Element type for A matrix operand
|
|
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
|
constexpr int AlignmentA = 2 * 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k
|
|
|
|
// E matrix config
|
|
using ElementE = cute::uint8_t;
|
|
|
|
// B matrix configuration
|
|
using ElementB = half_t; // Element type for B matrix operand
|
|
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
|
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
|
|
|
// C/D matrix configuration
|
|
using ElementD = float; // Element type for D matrix operand
|
|
using ElementC = float; // Element type for C matrix operand
|
|
using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand
|
|
using LayoutTagD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand
|
|
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
|
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
|
|
|
// Kernel functional config
|
|
using ElementAccumulator = float; // Element type for internal accumulation
|
|
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
|
using OperatorClass = cutlass::arch::OpClassSparseTensorOp; // Operator class tag
|
|
|
|
// MMA and Cluster Tile Shapes
|
|
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
|
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
|
// Shape of the threadblocks in a cluster
|
|
using ClusterShape_MNK = Shape<_2,_1,_1>;
|
|
|
|
// Build the epilogue
|
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
|
ArchTag, OperatorClass,
|
|
MmaTileShape_MNK, ClusterShape_MNK,
|
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
|
ElementAccumulator, ElementAccumulator,
|
|
ElementC, LayoutTagC, AlignmentC,
|
|
ElementD, LayoutTagD, AlignmentD,
|
|
cutlass::epilogue::TmaWarpSpecialized2Sm
|
|
>::CollectiveOp;
|
|
|
|
// Build the mainloop
|
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
|
ArchTag, OperatorClass,
|
|
ElementA, LayoutTagA, AlignmentA,
|
|
ElementB, LayoutTagB, AlignmentB,
|
|
ElementAccumulator,
|
|
MmaTileShape_MNK, ClusterShape_MNK,
|
|
cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,
|
|
cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100
|
|
>::CollectiveOp;
|
|
|
|
using ProblemShape = Shape<int,int,int,int>;
|
|
|
|
// Compose into a kernel
|
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
|
ProblemShape,
|
|
CollectiveMainloop,
|
|
CollectiveEpilogue,
|
|
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
|
|
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
|
|
|
// Reference device GEMM implementation type
|
|
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
|
ElementA,
|
|
LayoutTagA,
|
|
ElementB,
|
|
LayoutTagB,
|
|
ElementC,
|
|
LayoutTagC,
|
|
ElementAccumulator,
|
|
ElementAccumulator>;
|
|
|
|
// Layouts
|
|
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
|
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
|
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
|
using StrideE = StrideA;
|
|
using StrideB = typename Gemm::GemmKernel::StrideB;
|
|
using StrideC = typename Gemm::GemmKernel::StrideC;
|
|
using StrideD = typename Gemm::GemmKernel::StrideD;
|
|
|
|
//
|
|
// Compressor
|
|
//
|
|
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
|
|
|
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
|
ProblemShape,
|
|
ElementA,
|
|
LayoutTagA,
|
|
SparseConfig>;
|
|
|
|
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
|
ProblemShape,
|
|
ElementA,
|
|
LayoutTagA,
|
|
SparseConfig,
|
|
ArchTag>;
|
|
|
|
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
|
|
|
//
|
|
// Data members
|
|
//
|
|
|
|
/// Initialization
|
|
LayoutA layout_A;
|
|
LayoutE layout_E;
|
|
StrideA stride_A;
|
|
StrideA stride_A_compressed;
|
|
StrideE stride_E;
|
|
StrideB stride_B;
|
|
StrideC stride_C;
|
|
StrideD stride_D;
|
|
|
|
uint64_t seed;
|
|
|
|
ProblemShape problem_shape;
|
|
|
|
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
|
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A_compressed;
|
|
cutlass::DeviceAllocation<typename Gemm::CollectiveMainloop::ElementE> block_E;
|
|
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
|
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
|
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
|
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
|
|
|
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
/// Testbed utility types
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Command line options parsing
|
|
struct Options {
|
|
|
|
bool help;
|
|
|
|
float alpha, beta;
|
|
int iterations;
|
|
int m, n, k, l;
|
|
|
|
Options():
|
|
help(false),
|
|
m(8192), n(8192), k(8192), l(1),
|
|
alpha(1.f), beta(0.f),
|
|
iterations(10)
|
|
{ }
|
|
|
|
// Parses the command line
|
|
void parse(int argc, char const **args) {
|
|
cutlass::CommandLine cmd(argc, args);
|
|
|
|
if (cmd.check_cmd_line_flag("help")) {
|
|
help = true;
|
|
return;
|
|
}
|
|
|
|
cmd.get_cmd_line_argument("m", m);
|
|
cmd.get_cmd_line_argument("n", n);
|
|
cmd.get_cmd_line_argument("k", k);
|
|
cmd.get_cmd_line_argument("l", l);
|
|
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
|
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
|
cmd.get_cmd_line_argument("iterations", iterations);
|
|
}
|
|
|
|
/// Prints the usage statement.
|
|
std::ostream & print_usage(std::ostream &out) const {
|
|
|
|
out << "83_blackwell_sparse_gemm\n\n"
|
|
<< " Blackwell FP16 Sparse GEMM example.\n\n"
|
|
<< "Options:\n\n"
|
|
<< " --help If specified, displays this usage statement\n\n"
|
|
<< " --m=<int> Sets the M extent of the GEMM\n"
|
|
<< " --n=<int> Sets the N extent of the GEMM\n"
|
|
<< " --k=<int> Sets the K extent of the GEMM\n"
|
|
<< " --l=<int> Sets the L extent of the GEMM\n"
|
|
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
|
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
|
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
|
|
|
out
|
|
<< "\n\nExamples:\n\n"
|
|
<< "$ " << "83_blackwell_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
|
|
|
return out;
|
|
}
|
|
|
|
/// Compute performance in GFLOP/s
|
|
double gflops(double runtime_s) const
|
|
{
|
|
// Two flops per multiply-add
|
|
uint64_t flop = uint64_t(2) * m * n * k;
|
|
double gflop = double(flop) / double(1.0e9);
|
|
return gflop / runtime_s;
|
|
}
|
|
};
|
|
|
|
/// Result structure
|
|
struct Result
|
|
{
|
|
double avg_runtime_ms;
|
|
double gflops;
|
|
cutlass::Status status;
|
|
cudaError_t error;
|
|
bool passed;
|
|
|
|
Result(
|
|
double avg_runtime_ms = 0,
|
|
double gflops = 0,
|
|
cutlass::Status status = cutlass::Status::kSuccess,
|
|
cudaError_t error = cudaSuccess)
|
|
:
|
|
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
|
{}
|
|
|
|
};
|
|
|
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
/// GEMM setup and evaluation
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Helper to initialize a block of device data
|
|
template <class Element>
|
|
bool initialize_block(
|
|
cutlass::DeviceAllocation<Element>& block,
|
|
uint64_t seed=2023) {
|
|
|
|
Element scope_max, scope_min;
|
|
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
|
|
|
if constexpr (bits_input == 1) {
|
|
scope_max = Element(2);
|
|
scope_min = Element(0);
|
|
}
|
|
else if constexpr (bits_input <= 8) {
|
|
scope_max = Element(2);
|
|
scope_min = Element(-2);
|
|
}
|
|
else {
|
|
scope_max = Element(8);
|
|
scope_min = Element(-8);
|
|
}
|
|
cutlass::reference::device::BlockFillRandomUniform(
|
|
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
|
return true;
|
|
}
|
|
|
|
/// Make A structured sparse by replacing elements with 0 and compress it
|
|
bool sparsify_and_compress()
|
|
{
|
|
auto [M, N, K, L] = problem_shape;
|
|
CompressorUtility compressor_utility(problem_shape, stride_A);
|
|
|
|
// TensorE
|
|
// In unit of ElementE (uint8_t), after alignment requirement
|
|
// M-dim: TensorEAtom_M alignment
|
|
// K-dim: TensorEAtom_K alignment
|
|
int KAlignedE = compressor_utility.get_metadata_k_physical();
|
|
int MAlignedE = compressor_utility.get_metadata_m_physical();
|
|
|
|
// TensorA Compressed
|
|
// In unit of ElementARaw, after alignment requirement
|
|
// M-dim: TMA alignment
|
|
// K-dim: TMA alignment
|
|
int KAlignedAC = compressor_utility.get_tensorA_k_physical();
|
|
int MAlignedAC = compressor_utility.get_tensorA_m_physical();
|
|
|
|
block_A_compressed.reset(M * KAlignedAC * L);
|
|
block_E.reset(MAlignedE * KAlignedE * L);
|
|
|
|
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L));
|
|
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L));
|
|
|
|
// Random 50% fill zero is performed on host
|
|
std::vector<ElementA> block_A_host(block_A.size());
|
|
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
|
|
compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast<int>(seed + 2024));
|
|
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
|
|
|
|
cutlass::KernelHardwareInfo hw_info;
|
|
hw_info.device_id = 0;
|
|
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
|
typename Compressor::Arguments arguments {
|
|
problem_shape,
|
|
{ block_A.get(),
|
|
stride_A,
|
|
block_A_compressed.get(),
|
|
block_E.get() },
|
|
{hw_info} };
|
|
|
|
Compressor compressor_op;
|
|
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
|
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
|
|
|
cutlass::Status status {cutlass::Status::kSuccess };
|
|
status = compressor_op.can_implement(arguments);
|
|
if (status != cutlass::Status::kSuccess) {
|
|
return false;
|
|
}
|
|
|
|
status = compressor_op.initialize(arguments, workspace.get());
|
|
if (status != cutlass::Status::kSuccess) {
|
|
return false;
|
|
}
|
|
|
|
status = compressor_op.run();
|
|
if (status != cutlass::Status::kSuccess) {
|
|
return false;
|
|
}
|
|
|
|
auto result = cudaDeviceSynchronize();
|
|
if (result != cudaSuccess) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Initialize operands to be used in the GEMM and reference GEMM
|
|
bool initialize(const Options &options) {
|
|
|
|
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
|
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
|
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
|
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
|
|
|
block_A.reset(options.m * options.k);
|
|
block_B.reset(options.k * options.n);
|
|
block_C.reset(options.m * options.n);
|
|
block_D.reset(options.m * options.n);
|
|
block_ref_D.reset(options.m * options.n);
|
|
|
|
initialize_block(block_A, seed + 2023);
|
|
initialize_block(block_B, seed + 2022);
|
|
initialize_block(block_C, seed + 2021);
|
|
|
|
// Compress row A and get A_compress and E
|
|
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
|
|
if (not sparsify_and_compress()) {
|
|
return false;
|
|
};
|
|
|
|
// Build the compressed/metadata layouts
|
|
layout_A = SparseConfig::fill_layoutA(problem_shape);
|
|
layout_E = SparseConfig::fill_layoutE(problem_shape);
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Populates a Gemm::Arguments structure from the given commandline options
|
|
typename Gemm::Arguments args_from_options(const Options &options)
|
|
{
|
|
typename Gemm::Arguments arguments {
|
|
cutlass::gemm::GemmUniversalMode::kGemm,
|
|
problem_shape,
|
|
{ block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E },
|
|
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
|
};
|
|
|
|
return arguments;
|
|
}
|
|
|
|
bool verify(const Options &options) {
|
|
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
|
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
|
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
|
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
|
|
|
//
|
|
// Compute reference output
|
|
//
|
|
|
|
// Create instantiation for device reference gemm kernel
|
|
DeviceGemmReference gemm_reference;
|
|
|
|
// Launch device reference gemm kernel
|
|
gemm_reference(
|
|
{options.m, options.n, options.k},
|
|
ElementAccumulator(options.alpha),
|
|
ref_A,
|
|
ref_B,
|
|
ElementAccumulator(options.beta),
|
|
ref_C,
|
|
ref_D);
|
|
|
|
// Wait for kernel to finish
|
|
CUDA_CHECK(cudaDeviceSynchronize());
|
|
|
|
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
|
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
|
|
|
return passed;
|
|
}
|
|
|
|
/// Execute a given example GEMM computation
|
|
template <typename Gemm>
|
|
int run(Options &options)
|
|
{
|
|
auto init_pass = initialize(options);
|
|
if (not init_pass) {
|
|
std::cout << "Initialization failure" << std::endl;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
|
|
// Instantiate CUTLASS kernel depending on templates
|
|
Gemm gemm;
|
|
|
|
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
|
auto arguments = args_from_options(options);
|
|
|
|
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
|
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
|
|
|
// Allocate workspace memory
|
|
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
|
|
|
// Check if the problem size is supported or not
|
|
CUTLASS_CHECK(gemm.can_implement(arguments));
|
|
|
|
// Initialize CUTLASS kernel with arguments and workspace pointer
|
|
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
|
|
|
// Correctness / Warmup iteration
|
|
CUTLASS_CHECK(gemm.run());
|
|
|
|
cudaDeviceSynchronize();
|
|
|
|
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
|
Result result;
|
|
result.passed = verify(options);
|
|
|
|
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
|
|
|
if (not result.passed) {
|
|
exit(-1);
|
|
}
|
|
|
|
// Run profiling loop
|
|
if (options.iterations > 0)
|
|
{
|
|
GpuTimer timer;
|
|
timer.start();
|
|
for (int iter = 0; iter < options.iterations; ++iter) {
|
|
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
|
CUTLASS_CHECK(gemm.run());
|
|
}
|
|
timer.stop();
|
|
|
|
// Compute average runtime and GFLOPs.
|
|
float elapsed_ms = timer.elapsed_millis();
|
|
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
|
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
|
|
|
|
|
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
|
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
|
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
int main(int argc, char const **args) {
|
|
|
|
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
|
// and must have compute capability at least 100.
|
|
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
|
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
|
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
|
return 0;
|
|
}
|
|
|
|
cudaDeviceProp props;
|
|
int current_device_id;
|
|
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
|
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
|
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
|
if (not (props.major == 10 && props.minor == 0)) {
|
|
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
|
return 0;
|
|
}
|
|
|
|
//
|
|
// Parse options
|
|
//
|
|
|
|
Options options;
|
|
|
|
options.parse(argc, args);
|
|
|
|
if (options.help) {
|
|
options.print_usage(std::cout) << std::endl;
|
|
return 0;
|
|
}
|
|
|
|
//
|
|
// Evaluate CUTLASS kernels
|
|
//
|
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
|
run<Gemm>(options);
|
|
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
|
|
|
return 0;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|