@@ -162,7 +162,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // This code section describes the epilogue part of the kernel
 using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -161,7 +161,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // This code section describes the epilogue part of the kernel
 using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
 //
@@ -231,7 +231,7 @@ struct B2bFusedGroupedGemmRun
       host_tensor_ref_D1.at(i).sync_device();

       ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
-      ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());;
+      ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
       ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
       if (alpha0 == ElementCompute(0)) //per-channel scale
         ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
@@ -340,7 +340,7 @@ struct B2bFusedGroupedGemmRun
     std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";

     for (int i = 0; i < problem_count; ++i) {
-      host_tensor_D1.at(i).sync_host();;
+      host_tensor_D1.at(i).sync_host();

       //
       // Verify
@@ -194,7 +194,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // This code section describes the epilogue part of the kernel
 using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -33,6 +33,11 @@ cutlass_example_add_executable(
   ampere_sparse_tensorop_gemm.cu
   )

+cutlass_example_add_executable(
+  15_ampere_sparse_tensorop_gemm_universal
+  ampere_sparse_tensorop_gemm_universal.cu
+  )
+
 cutlass_example_add_executable(
   15_ampere_sparse_tensorop_gemm_with_visitor
   ampere_sparse_tensorop_gemm_with_visitor.cu
@@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // This code section describes the epilogue part of the kernel
 using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -0,0 +1,329 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
Please check examples 07, 08 and 17 for the basics of dense tensor op GEMM kernels. The NVIDIA Ampere
architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4.

A sparse GEMM kernel needs to take an additional matrix E, which stores the metadata. The metadata
format is different for each data type. CUTLASS templates can automatically infer it based on
input A and B. Check the code below.

Moreover, matrix E needs to be preprocessed so that it can be loaded into registers efficiently
with ldmatrix.
*/
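// (Editorial sketch, not part of the original patch.)  The extent bookkeeping implied by the
// comment above, assuming kSparse == 2 (2:4 structured sparsity): the compressed A operand is
// stored as M x (K / kSparse) and the metadata E as M x (K / kSparse / kElementsPerElementE),
// which is exactly how tensor_a and tensor_e are sized further below.
constexpr int sketch_compressed_a_cols(int k, int sparse) { return k / sparse; }
constexpr int sketch_meta_e_cols(int k, int sparse, int elems_per_e) { return k / sparse / elems_per_e; }
static_assert(sketch_compressed_a_cols(1024, 2) == 512, "for K = 1024, A is stored with 512 columns");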
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_sparse_universal.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/host_uncompress.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output matrices and computation between
|
||||
// elements in input matrices.
|
||||
using ElementAccumulator = int32_t; // <- data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||
using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = int32_t; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
|
||||
// memory access. For a byte, it's 16
|
||||
// elements. This becomes the vector width of
|
||||
// math instructions in the epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
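// (Editorial note, not part of the original patch.)  With ElementOutput = int32_t the expression
// above evaluates to 128 / 32 = 4 elements per vectorized access; the "16 elements" in the comment
// refers to one-byte output types (128 / 8 = 16).
static_assert(128 / cutlass::sizeof_bits<ElementOutput>::value == 4,
              "int32_t outputs are accessed 4 elements (128 bits) at a time");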
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmSparseUniversal<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||
using ElementInputE = typename Gemm::ElementE;
|
||||
using LayoutInputE = cutlass::layout::RowMajor;
|
||||
using ReorderedLayoutInputE = typename Gemm::LayoutE;
|
||||
|
||||
// The properties below are defined in include/cutlass/arch/sp_mma_sm80.h
|
||||
// 50% Sparsity on Ampere
|
||||
constexpr int kSparse = Gemm::kSparse;
|
||||
// How many elements of A are covered per ElementE
|
||||
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||
// The size of individual meta data
|
||||
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
|
||||
int run() {
|
||||
|
||||
const int length_m = 512;
|
||||
const int length_n = 512;
|
||||
const int length_k = 1024;
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
|
||||
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
|
||||
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(2),
|
||||
ElementInputA(-2),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(2),
|
||||
ElementInputB(-2),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(2),
|
||||
ElementOutput(-2),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||
tensor_e.host_view(),
|
||||
1,
|
||||
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
|
||||
// instructions.
|
||||
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
|
||||
{problem_size.m(), problem_size.n(),
|
||||
problem_size.k() / kSparse / kElementsPerElementE});
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_e_reordered.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
|
||||
|
||||
// Split K dimension into 2 partitions
int split_k_slices = 2;
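// (Editorial sketch, not part of the original patch.)  With GemmUniversalMode::kGemm, a
// split_k_slices value greater than 1 makes the kernel process the K extent in that many slices
// and reduce the partial results into the same output tile.  Roughly, slice i covers K indices
// [i * K / split_k_slices, (i + 1) * K / split_k_slices), so for K = 1024 and 2 slices the two
// slices cover [0, 512) and [512, 1024); the real kernel splits on K-tile granularity, so the
// boundaries may round differently.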
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size, // <- problem size of matrix multiplication
|
||||
split_k_slices,// <- k-dimension split factor
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
tensor_a.device_data(), // <- reference to matrix A on device
|
||||
tensor_b.device_data(), // <- reference to matrix B on device
|
||||
tensor_c.device_data(), // <- reference to matrix C on device
|
||||
tensor_d.device_data(), // <- reference to matrix D on device
|
||||
tensor_e_reordered.device_data(), // <- reference to matrix E on device
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
int64_t(),
|
||||
tensor_a.layout().stride(0),
|
||||
tensor_b.layout().stride(0),
|
||||
tensor_c.layout().stride(0),
|
||||
tensor_d.layout().stride(0),
|
||||
tensor_e_reordered.layout().stride(0)
|
||||
};
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// uncompress tensor_a based on meta data tensor_e. We need it for reference computing.
|
||||
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
|
||||
tensor_e.host_ref(), problem_size.m(), problem_size.k());
|
||||
|
||||
// Create instantiation for host reference gemm kernel
|
||||
cutlass::reference::host::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
typename Gemm::Operator>
|
||||
gemm_host;
|
||||
|
||||
// Launch host reference gemm kernel
|
||||
gemm_host(problem_size,
|
||||
alpha,
|
||||
tensor_a_uncompressed.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
beta,
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref());
|
||||
|
||||
// Copy output data from CUTLASS host for comparison
|
||||
tensor_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major * 10 + props.minor < 80) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
return run();
|
||||
}
|
||||
@@ -94,7 +94,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 1
 // This code section describes how threadblocks are scheduled on GPU
 using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

-using Operator = cutlass::arch::OpMultiplyAdd;
+using Operator = cutlass::arch::OpMultiplyAddSaturate;

 // Number of pipelines you want to use
 constexpr int NumStages = 3;
@@ -138,7 +138,7 @@ using Gemm = typename cutlass::gemm::device::GemmWithKReduction<
 >;

 // Below is the reduction kernel used in the case of parallel split-k
-using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
+using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;

 using ReduceOp = cutlass::reduction::thread::ReduceAdd<
     ElementAccumulator,
@@ -154,7 +154,7 @@ using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<

 using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK<ReduceGemmSplitKKernel>;

-using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;;
+using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;

 // This code section describes the epilogue part of the kernel, we use default value
 using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -258,7 +258,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // This code section describes the epilogue part of the kernel
 using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -221,7 +221,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // This code section describes the epilogue part of the kernel
 using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -193,7 +193,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M =
 using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // This code section describes the epilogue part of the kernel
 using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -215,7 +215,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8
 // 16, 8, 16 -> Ampere

 // This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

 // Define the epilogue operation as LinearCombination. This is approximately equal to
 //
@@ -118,7 +118,7 @@ operation = Conv2dOperation(
     conv_kind=cutlass_bindings.conv.Operator.fprop,
     iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized,
     arch=cc, tile_description=tile_description,
-    A=A, B=B, C=C, stride_support=StrideSupport.Strided,
+    A=A, B=B, C=C, stride_support=StrideSupport.Unity,
     epilogue_functor=epilogue_functor
 )

@@ -1799,7 +1799,7 @@ struct B2bGemm<
         if (rowIdx == 1) {
           lse_prefetched[colIdx] = accum_n < lse_extent
               ? lse[accum_n]
-              : platform::numeric_limits<accum_t>::infinity();
+              : cutlass::platform::numeric_limits<accum_t>::infinity();
         }
         accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
         ++colIdx;
@@ -1938,7 +1938,7 @@ struct B2bGemm<
         if (rowIdx == 1) {
           lse_prefetched[colIdx] = accum_n < lse_extent
               ? lse[accum_n]
-              : platform::numeric_limits<accum_t>::infinity();
+              : cutlass::platform::numeric_limits<accum_t>::infinity();
         }
         accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
         ++colIdx;
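Both hunks above apply a precomputed log-sum-exp (LSE) to raw accumulator values: columns at or
beyond lse_extent get an LSE of +infinity, so their exponential becomes zero. A minimal host-side
sketch of the same math (not part of the patch, hypothetical names):

    #include <cmath>
    #include <limits>
    #include <vector>

    // Exponentiate scores against a per-column log-sum-exp; padded columns (n >= lse_extent)
    // use +infinity so that exp(score - inf) == 0, mirroring the device code above.
    void apply_lse(std::vector<float>& accum, const std::vector<float>& lse, int lse_extent) {
      for (int n = 0; n < static_cast<int>(accum.size()); ++n) {
        float l = (n < lse_extent) ? lse[n] : std::numeric_limits<float>::infinity();
        accum[n] = std::exp(accum[n] - l);
      }
    }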
@@ -35,19 +35,23 @@
    This example demonstrates a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0
    APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:

-   1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) 
+   1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
    which are more efficient than the Ampere tensor core instructions.

-   2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large 
+   2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
    blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
    copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and
    convert them implicitly to TF32.

    3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).

+   4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
+   CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
+   improve performance.
+
    Examples:

-   $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048
+   $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2
 */

 #include <iostream>
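For intuition about the two new knobs, a rough sketch (not part of the patch, and much simpler than
the real Sm90 persistent tile scheduler) of how a rasterization direction plus a swizzle factor can
map a linear CTA index onto tile coordinates:

    #include <utility>

    // Illustrative only: visit `swizzle` adjacent N-tiles before stepping in M, so nearby CTAs
    // share the same A row panel.  tiles_m is the number of output tiles along M.
    std::pair<int, int> sketch_raster_along_n(int cta_id, int tiles_m, int swizzle) {
      int band    = cta_id / (tiles_m * swizzle);   // which band of `swizzle` N-tiles
      int in_band = cta_id % (tiles_m * swizzle);
      int tile_m  = in_band / swizzle;
      int tile_n  = band * swizzle + in_band % swizzle;
      return {tile_m, tile_n};
    }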
@ -63,6 +67,7 @@
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
@ -105,7 +110,7 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
|
||||
using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
@ -175,6 +180,8 @@ cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
@ -183,12 +190,16 @@ struct Options {
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(5120), n(4096), k(4096),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(1000)
|
||||
iterations(1000),
|
||||
raster(RasterOrderOptions::Heuristic),
|
||||
swizzle(1)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
@ -206,6 +217,21 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -220,6 +246,8 @@ struct Options {
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
@ -294,10 +322,10 @@ bool initialize_block(
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
@ -320,6 +348,10 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
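A sketch (not part of the patch) of what the documented clamping above amounts to: the requested
swizzle is applied as one of {1, 2, 4, 8}, rounding down. The actual logic lives in the CUTLASS
tile scheduler; this is only an illustration.

    // Round a requested swizzle down to the nearest of {1, 2, 4, 8}.
    int sketch_effective_swizzle(int requested) {
      int s = 1;
      while (s * 2 <= requested && s * 2 <= 8) { s *= 2; }
      return s;
    }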
|
||||
|
||||
@ -408,7 +440,17 @@ int run(Options &options)
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
@ -441,7 +483,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -538,9 +538,8 @@ int main(int argc, char const **args) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -354,9 +354,8 @@ int main(int argc, char const **args) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -627,7 +627,6 @@ int main(int argc, const char ** argv) {
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -166,7 +166,7 @@ public:
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
|
||||
// swap M/N
|
||||
get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
@ -181,8 +181,7 @@ public:
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE static
|
||||
bool
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
||||
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
|
||||
@ -119,7 +119,7 @@ public:
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
CUTLASS_HOST_DEVICE static bool
|
||||
static bool
|
||||
can_implement(
|
||||
[[maybe_unused]] ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
|
||||
@ -750,7 +750,6 @@ int main(int argc, char const **argv)
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -47,9 +47,13 @@
|
||||
4. This example shows all important fusions used by FP8 gemm kernels,
|
||||
i.e., scale factor for A, B, C, D tensor, the abs_max value of D tensor.
|
||||
|
||||
5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
|
||||
Examples:
|
||||
|
||||
$ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048
|
||||
$ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
@ -63,6 +67,7 @@
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
@ -214,6 +219,8 @@ cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
@ -273,7 +280,7 @@ bool initialize_tensor(
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
@ -346,7 +353,7 @@ void initialize(const Options &options) {
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
@ -392,10 +399,14 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
bool verify(const Options<RasterOrderOptions> &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
@ -468,7 +479,7 @@ bool verify(const Options &options) {
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
int run(Options<RasterOrderOptions> &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
@ -518,7 +529,17 @@ int run(Options &options)
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
@ -551,12 +572,11 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
Options<RasterOrderOptions> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
template<typename RasterOrderOptions>
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
@ -41,6 +42,8 @@ struct Options {
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@ -66,6 +69,21 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
@ -89,6 +107,8 @@ struct Options {
|
||||
<< " --device_scale=<bool> Copy scalars to device memory before kernel launch (default: false)\n"
|
||||
<< " --save_aux=<bool> Save the pre-activation as an auxiliary tensor (default: true)\n"
|
||||
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
|
||||
@ -687,7 +687,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -99,7 +99,7 @@ using TileShape = Shape<_256,_128,_64>; // T
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
@ -492,7 +492,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -30,10 +30,10 @@
|
||||
set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes
|
||||
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=1) # Default problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Default problem sizes w/ Epilogue Op test
|
||||
set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test
|
||||
|
||||
set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
/*! \file
|
||||
\brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
|
||||
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
|
||||
warp-specialized cooperative kernel.
|
||||
For this example all scheduling work is performed on the device.
|
||||
The new feature showcased in this example is on-the-fly modification of TMA descriptors
|
||||
@ -42,7 +42,7 @@
|
||||
|
||||
$ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
|
||||
@ -117,7 +117,7 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
|
||||
using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
|
||||
@ -163,10 +163,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
@ -226,7 +226,7 @@ struct Options {
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@ -438,10 +438,10 @@ void allocate(const Options &options) {
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
|
||||
}
|
||||
|
||||
@ -456,7 +456,7 @@ void allocate(const Options &options) {
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
@ -695,7 +695,6 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -97,7 +97,7 @@ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWith
|
||||
cutlass::epilogue::thread::ReLu,
|
||||
ElementOutput,
|
||||
ElementAuxOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
8,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>;
|
||||
@ -106,7 +106,7 @@ template <typename MathOperator>
|
||||
using Gemm_ = cutlass::gemm::device::GemmUniversalWithAbsMax<
|
||||
ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89,
|
||||
cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
|
||||
cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
|
||||
EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages,
|
||||
kAlignmentA, kAlignmentB, MathOperator
|
||||
>;
|
||||
|
||||
@ -53,3 +53,8 @@ cutlass_example_add_executable(
|
||||
tiled_copy.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
wgmma_sm90
|
||||
wgmma_sm90.cu
|
||||
)
|
||||
|
||||
|
||||
@ -153,12 +153,12 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
// Allocate the accumulators -- same size as the projected data
|
||||
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K)
|
||||
CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K)
|
||||
CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K)
|
||||
CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K)
|
||||
CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
|
||||
CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K
|
||||
|
||||
// Clear the accumulators
|
||||
clear(tCrC);
|
||||
@@ -358,7 +358,7 @@ gemm_nt(int m, int n, int k,
               alpha, beta);
 }

-// Setup params for a NT GEMM
+// Setup params for a TN GEMM
 template <class TA, class TB, class TC,
           class Alpha, class Beta>
 void
@@ -391,10 +391,10 @@ gemm_tn(int m, int n, int k,
   auto bP = Int<3>{};  // Pipeline

   // Define the smem layouts (static)
-  auto sA_atom = make_layout(make_shape ( bM, bK),
-                             make_stride(Int<1>{}, bM+Int<1>{}));  // (m,k) -> smem_idx; padded m-major
-  auto sB_atom = make_layout(make_shape ( bN, bK),
-                             make_stride(Int<1>{}, bN+Int<1>{}));  // (n,k) -> smem_idx; padded n-major
+  auto sA_atom                  = make_layout(make_shape ( bM, bK),
+                                              make_stride(Int<1>{}, bM+Int<1>{}));  // (m,k) -> smem_idx; padded m-major
+  [[maybe_unused]] auto sB_atom = make_layout(make_shape ( bN, bK),
+                                              make_stride(Int<1>{}, bN+Int<1>{}));  // (n,k) -> smem_idx; padded n-major
   auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
   auto sB = tile_to_shape(sA_atom, make_shape(bN, bK, bP));
   auto sC = make_layout(make_shape(bM, bN));  // (m,n) -> smem_idx
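The padded strides in the hunk above are the classic trick for avoiding shared-memory bank
conflicts: a layout of shape (bM, bK) with strides (1, bM + 1) places element (m, k) at offset
m + k * (bM + 1), so for a hypothetical bM = 128 the element (3, 2) lands at 3 + 2 * 129 = 261
and neighbouring k-columns sit 129 elements apart instead of 128.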
examples/cute/tutorial/wgmma_sm90.cu (new file, 562 lines)
@@ -0,0 +1,562 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <cstdlib>
|
||||
#include <cstdio>
|
||||
#include <cassert>
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
#include "cutlass/pipeline/sm90_pipeline.hpp"
|
||||
|
||||
#include "cutlass/util/print_error.hpp"
|
||||
#include "cutlass/util/GPU_Clock.hpp"
|
||||
#include "cutlass/util/helper_cuda.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <class ElementA,
|
||||
class ElementB,
|
||||
class SmemLayoutA, // (M,K,P)
|
||||
class SmemLayoutB> // (N,K,P)
|
||||
struct SharedStorage
|
||||
{
|
||||
array_aligned<ElementA, cosize_v<SmemLayoutA>> smem_A;
|
||||
array_aligned<ElementB, cosize_v<SmemLayoutB>> smem_B;
|
||||
|
||||
uint64_t tma_barrier[size<2>(SmemLayoutA{})];
|
||||
uint64_t mma_barrier[size<2>(SmemLayoutA{})];
|
||||
};
|
||||
|
||||
template <class ProblemShape, class CtaTiler,
|
||||
class TA, class SmemLayoutA, class TmaA,
|
||||
class TB, class SmemLayoutB, class TmaB,
|
||||
class TC, class CStride, class TiledMma,
|
||||
class Alpha, class Beta>
|
||||
__global__ static
|
||||
__launch_bounds__(decltype(size(TiledMma{}))::value)
|
||||
void
|
||||
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
|
||||
TA const* A, CUTLASS_GRID_CONSTANT TmaA const tma_a,
|
||||
TB const* B, CUTLASS_GRID_CONSTANT TmaB const tma_b,
|
||||
TC * C, CStride dC, TiledMma mma,
|
||||
Alpha alpha, Beta beta)
|
||||
{
|
||||
// Preconditions
|
||||
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
|
||||
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
|
||||
|
||||
static_assert(is_static<SmemLayoutA>::value);
|
||||
static_assert(is_static<SmemLayoutB>::value);
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutA{}) == size<0>(cta_tiler)); // BLK_M
|
||||
CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutB{}) == size<1>(cta_tiler)); // BLK_N
|
||||
CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutA{}) == size<2>(cta_tiler)); // BLK_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutB{}) == size<2>(cta_tiler)); // BLK_K
|
||||
|
||||
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
|
||||
|
||||
  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  auto [M, N, K] = shape_MNK;
  Tensor mA = tma_a.get_tma_tensor(make_shape(M,K));                   // (M,K) TMA Tensor
  Tensor mB = tma_b.get_tma_tensor(make_shape(N,K));                   // (N,K) TMA Tensor
  Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC);      // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory tensors
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage<TA, TB, SmemLayoutA, SmemLayoutB>;
  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
  Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{});  // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{});  // (BLK_N,BLK_K,PIPE)

  //
  // Partition the copying of A and B tiles
  //
  // TUTORIAL:
  //   These are TMA partitionings, which have a dedicated custom partitioner.
  //   The Int<0>, Layout<_1> indicates that the TMAs are not multicasted.
  //     Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host.
  //   The group_modes<0,2> transforms the (X,Y,Z)-shaped tensors into ((X,Y),Z)-shaped tensors
  //     with the understanding that the TMA is responsible for everything in mode-0.
  //   The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info.
  //

  auto [tAgA, tAsA] = tma_partition(tma_a, Int<0>{}, Layout<_1>{},
                                    group_modes<0,2>(sA), group_modes<0,2>(gA));  // (TMA,k) and (TMA,PIPE)

  auto [tBgB, tBsB] = tma_partition(tma_b, Int<0>{}, Layout<_1>{},
                                    group_modes<0,2>(sB), group_modes<0,2>(gB));  // (TMA,k) and (TMA,PIPE)

  // The TMA is responsible for copying everything in mode-0 of tAsA and tBsB
  constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) +
                                       CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB);

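  // NOTE: kTmaTransactionBytes is the byte count a single pipeline stage expects from the TMA:
  // one (BLK_M,BLK_K) tile of A plus one (BLK_N,BLK_K) tile of B. It is registered with the
  // producer transaction barrier via arrive_and_expect_tx before each pair of loads below.
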
  //
  // PREFETCH
  //

  auto K_PIPE_MAX = size<1>(tAsA);

  // Total count of tiles
  int k_tile_count = size<1>(tAgA);
  // Current tile index in gmem to read from
  int k_tile = 0;

  // Initialize Barriers
  int warp_idx = cutlass::canonical_warp_idx_sync();
  int lane_predicate = cute::elect_one_sync();
  uint64_t* producer_mbar = smem.tma_barrier;
  uint64_t* consumer_mbar = smem.mma_barrier;

  using ProducerBarType = cutlass::arch::ClusterTransactionBarrier;  // TMA
  using ConsumerBarType = cutlass::arch::ClusterBarrier;             // MMA
  CUTE_UNROLL
  for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe) {
    if ((warp_idx == 0) && lane_predicate) {
      ProducerBarType::init(&producer_mbar[pipe],   1);
      ConsumerBarType::init(&consumer_mbar[pipe], 128);
    }
  }
  // Ensure barrier init is complete on all CTAs
  cluster_sync();

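  // NOTE: each producer barrier expects a single arrival (only the one elected thread issues
  // the TMA and its expected-transaction count), while each consumer barrier expects 128
  // arrivals -- every thread of the warpgroup signals once it has finished reading a stage.
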
  // Start async loads for all pipes
  CUTE_UNROLL
  for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe)
  {
    if ((warp_idx == 0) && lane_predicate)
    {
      // Set expected Tx Bytes after each reset / init
      ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
      copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
      copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
    }
    --k_tile_count;
    ++k_tile;
  }

  //
  // Define A/B partitioning and C accumulators
  //
  // TUTORIAL:
  //   The tCrA and tCrB are actually Tensors of MMA Descriptors constructed as views of SMEM.
  //   The MMA Descriptor generation is automatic via inspection and validation of the SMEM Layouts.
  //   Because the MMA reads directly from SMEM and the fragments are descriptors rather than registers,
  //     there is no need for copy(tCsA, tCrA) in the mainloop.
  //

  ThrMMA thr_mma = mma.get_thread_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);                               // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thr_mma.partition_B(sB);                               // (MMA,MMA_N,MMA_K,PIPE)
  Tensor tCgC = thr_mma.partition_C(gC);                               // (MMA,MMA_M,MMA_N)

  // Allocate accumulators and clear them
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);                         // (MMA,MMA_M,MMA_N)
  clear(tCrC);

  // Allocate "fragments"
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);                         // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);                         // (MMA,MMA_N,MMA_K,PIPE)

  //
  // PIPELINED MAIN LOOP
  //
  // TUTORIAL:
  //   Rather than interleaving the stages and instructions like in SM70 and SM80,
  //   the SM90 mainloops rely on explicit producer-consumer synchronization
  //   on the purely async instructions TMA and MMA.
  //   More advanced pipeline and warp-specialization strategies are available in CUTLASS mainloops.
  //

  // A PipelineState is a circular pipe index [.index()] and a pipe phase [.phase()]
  // that flips each cycle through K_PIPE_MAX.
  auto write_state = cutlass::PipelineState<K_PIPE_MAX>();             // TMA writes
  auto read_state  = cutlass::PipelineState<K_PIPE_MAX>();             // MMA  reads

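  // NOTE: each iteration waits on the producer barrier for one full stage, issues the warpgroup
  // MMAs for that K-tile, signals the consumer barrier, and (from the single elected thread)
  // refills the freed stage with the next K-tile via TMA. The -K_PIPE_MAX loop bound keeps the
  // loop running until the stages prefetched above have also been consumed.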
  CUTE_NO_UNROLL
  while (k_tile_count > -K_PIPE_MAX)
  {
    // Wait for Producer to complete
    int read_pipe = read_state.index();
    ProducerBarType::wait(&producer_mbar[read_pipe], read_state.phase());

    // MMAs to cover 1 K_TILE
    warpgroup_arrive();
    gemm(mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);     // (V,M) x (V,N) => (V,M,N)
    warpgroup_commit_batch();

    // Wait for all MMAs in a K_TILE to complete
    warpgroup_wait<0>();

    // Notify that consumption is done
    ConsumerBarType::arrive(&consumer_mbar[read_pipe]);
    ++read_state;

    if ((warp_idx == 0) && lane_predicate)
    {
      int pipe = write_state.index();
      // Wait for Consumer to complete consumption
      ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase());
      // Set expected Tx Bytes after each reset / init
      ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes);
      copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe));
      copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe));
      ++write_state;
    }
    --k_tile_count;
    ++k_tile;
  }

  //
  // Epilogue (unpredicated)
  //

  axpby(alpha, tCrC, beta, tCgC);
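  // NOTE: axpby(alpha, x, beta, y) computes y = alpha * x + beta * y elementwise, so this writes
  // D = alpha * acc + beta * C straight to global memory. No bounds predication is applied, so
  // the problem shape is assumed to divide evenly into CTA tiles.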
}

// Setup params for an NT GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int< 64>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int< 3>{};                                       // Pipeline

  // Define the smem layouts (static)
  auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
  auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom<TB>{}, make_shape(bN,bK,bP));

  // Define the MMA
  TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});

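  // NOTE: Layout_MN_SW128_Atom is the 128B-swizzled, MN-major shared-memory atom; tiling it to
  // (bM,bK,bP) and (bN,bK,bP) yields the pipelined smem layouts the kernel expects. The _SS MMA
  // reads both operands directly from shared memory, matching these MN-major layouts.
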
  // Define the TMAs
  // Create Global memory tensors for TMA inspection
  Tensor mA = make_tensor(A, make_shape(M,K), dA);
  Tensor mB = make_tensor(B, make_shape(N,K), dB);

  // Create TMA Atoms with the desired copy operation on the source and destination
  Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK));
  Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK));

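  // NOTE: make_tma_atom builds the TMA descriptor on the host from the global-memory tensor,
  // a single pipeline stage of the smem layout (sA(_,_,0)), and the CTA tile to fetch; the
  // resulting atom is passed to the kernel by value as a CUTLASS_GRID_CONSTANT parameter.
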
  //
  // Setup and Launch
  //

  // Launch parameter setup
  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
  dim3 dimBlock(size(tiled_mma));
  dim3 dimCluster(2, 1, 1);
  dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
               round_up(size(ceil_div(n, bN)), dimCluster.y));
  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};

  void const* kernel_ptr = reinterpret_cast<void const*>(
                              &gemm_device<decltype(prob_shape), decltype(cta_tiler),
                                           TA, decltype(sA), decltype(tmaA),
                                           TB, decltype(sB), decltype(tmaB),
                                           TC, decltype(dC), decltype(tiled_mma),
                                           decltype(alpha), decltype(beta)>);

  CUTE_CHECK_ERROR(cudaFuncSetAttribute(
    kernel_ptr,
    cudaFuncAttributeMaxDynamicSharedMemorySize,
    smem_size));

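  // NOTE: opting in to a larger dynamic shared-memory carveout is required here: three pipeline
  // stages of 128x64 half-precision A and B tiles already total roughly 96 KB, well above the
  // default 48 KB per-block limit.
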
  // Kernel Launch
  cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
                                                             prob_shape, cta_tiler,
                                                             A, tmaA,
                                                             B, tmaB,
                                                             C, dC, tiled_mma,
                                                             alpha, beta);
  CUTE_CHECK_LAST();

  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Error: Failed at kernel launch" << std::endl;
  }
}

// Setup params for a TN GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int< 64>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};                                        // Pipeline

  // Define the smem layouts (static)
  auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom<TA>{}, make_shape(bM,bK,bP));
  auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom<TB>{}, make_shape(bN,bK,bP));

  // Define the MMA
  TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::K,GMMA::Major::K>{});

  // Define the TMAs
  // Create Global memory tensors for TMA inspection
  Tensor mA = make_tensor(A, make_shape(M,K), dA);
  Tensor mB = make_tensor(B, make_shape(N,K), dB);

  // Create TMA Atoms with the desired copy operation on the source and destination
  Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK));
  Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK));

  //
  // Setup and Launch
  //

  // Launch parameter setup
  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
  dim3 dimBlock(size(tiled_mma));
  dim3 dimCluster(2, 1, 1);
  dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x),
               round_up(size(ceil_div(n, bN)), dimCluster.y));
  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size};

  void const* kernel_ptr = reinterpret_cast<void const*>(
                              &gemm_device<decltype(prob_shape), decltype(cta_tiler),
                                           TA, decltype(sA), decltype(tmaA),
                                           TB, decltype(sB), decltype(tmaB),
                                           TC, decltype(dC), decltype(tiled_mma),
                                           decltype(alpha), decltype(beta)>);

  CUTE_CHECK_ERROR(cudaFuncSetAttribute(
    kernel_ptr,
    cudaFuncAttributeMaxDynamicSharedMemorySize,
    smem_size));

  // Kernel Launch
  cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr,
                                                             prob_shape, cta_tiler,
                                                             A, tmaA,
                                                             B, tmaB,
                                                             C, dC, tiled_mma,
                                                             alpha, beta);
  CUTE_CHECK_LAST();

  if (status != cutlass::Status::kSuccess) {
    std::cerr << "Error: Failed at kernel launch" << std::endl;
  }
}

template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  } else
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
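  // NOTE: only the NT ('N','T') and TN ('T','N') operand layouts are implemented by this example;
  // other combinations fall through to the assert below.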
  assert(false && "Not implemented");
}

int main(int argc, char** argv)
{

  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major != 9) {
    std::cout << "This example requires an NVIDIA Hopper architecture GPU (compute capability 90a).\n" << std::endl;
    return 0;
  }

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

  int m = 512;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 256;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 1024;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = cute::half_t;
  using TB = cute::half_t;
  using TC = cute::half_t;
  using TI = cute::half_t;

  TI alpha = TI(1.0f);
  TI beta  = TI(0.0f);

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  // Initialize the tensors
  for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1));
  for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1));
  for (int j = 0; j < m*n; ++j) h_C[j] = TC(0);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

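  // NOTE: 2*m*n*k counts one multiply and one add per accumulated product, and the 1e-9 factor
  // converts to GFLOP, so gflops divided by the measured seconds (reported below) is GFLOP/s.
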
  const int timing_iterations = 100;
  GPU_Clock timer;

  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    assert(false);
  }

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    assert(false);
  }

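  // NOTE: ldA and ldB select the stride of the non-contiguous mode of each operand, matching the
  // conventions assumed by gemm_nt (M- and N-major operands) and gemm_tn (K-major operands);
  // C is always M-major with ldC = m.
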
  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

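  // NOTE: cute_result holds the device output for optional inspection; this example does not run
  // a host reference check against it.
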
  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

#else

  std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived.\n" << std::endl;
#endif

  return 0;

}