Blockwise and Groupwise GEMM for Blackwell and Improvements for Hopper (#2139)

- Blockwise and Groupwise GEMM improvements for Hopper.
- Blockwise and Groupwise GEMM for Blackwell.
- Blockwise Grouped GEMM for Hopper.
- Static ScalePromotionInterval for Hopper FP8 GEMMs.

Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
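The headline addition is groupwise scaling whose granularity along M and N is a compile-time parameter of the kernel schedule. The sketch below illustrates how those parameters fit together; it mirrors the GroupScaleConfig helper introduced by the new example later in this diff, but ScaleScheduleSketch itself is a hypothetical name used only for illustration, not a drop-in build target.

    // Sketch only: derive per-tile scale-factor counts from a chosen granularity and select the
    // grouped (ptr-array) FP8 block-scaled kernel schedule used by the new Hopper example.
    #include "cute/tensor.hpp"
    #include "cutlass/gemm/dispatch_policy.hpp"

    template <int ScaleGranularityM, int ScaleGranularityN>
    struct ScaleScheduleSketch {
      using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;  // CTA tile used by the example

      // Number of scale factors stored per CTA tile along M and N.
      static constexpr int ScaleMsPerTile = cute::size<0>(TileShape{}) / ScaleGranularityM;
      static constexpr int ScaleNsPerTile = cute::size<1>(TileShape{}) / ScaleGranularityN;

      static_assert(cute::size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
                    "Scale granularity must evenly divide the tile shape along M.");
      static_assert(cute::size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile,
                    "Scale granularity must evenly divide the tile shape along N.");

      // Grouped FP8 schedule with block-scaled accumulation, as instantiated by the example.
      using KernelSchedule =
          cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum<
              ScaleGranularityM, ScaleGranularityN>;
    };

    // Usage: ScaleScheduleSketch<1, 128> corresponds to the example's 1Dx2D configuration
    // (one scale per row of A within a tile, one scale per CTA tile of B).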
@@ -398,6 +398,10 @@ void initialize(const Options<RasterOrderOptions> &options) {
blockscale_tensor_A.sync_device();
blockscale_tensor_B.sync_device();

// Note : This value has to match the KernelSchedule::ScalePromotionInterval;
// otherwise the kernel will fail the can_implement() check.
// Deprecation Notice : We plan to remove this params member in an upcoming release.
// Users can safely delete this line from their code, since the default is already 4.
mma_promotion_interval = 4;

if (options.save_aux) {
@@ -662,9 +666,11 @@ int run(Options<RasterOrderOptions> &options)
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
if (options.verify) {
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
@@ -674,8 +680,9 @@ int run(Options<RasterOrderOptions> &options)
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
if (iter == options.warmup)
|
||||
timer.start();
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
@@ -453,6 +453,10 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
blockscale_tensor_A.sync_device();
blockscale_tensor_B.sync_device();

// Note : This value has to match the KernelSchedule::ScalePromotionInterval;
// otherwise the kernel will fail the can_implement() check.
// Deprecation Notice : We plan to remove this params member in an upcoming release.
// Users can safely delete this line from their code, since the default is already 4.
mma_promotion_interval = 4;

if (options.save_aux) {
@@ -668,14 +672,14 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
if (false) {
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
#if 0
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
#endif
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.sync_host();
|
||||
@@ -729,13 +733,15 @@ int run(Options<RasterOrderOptions> &options)
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
|
||||
if (options.verify) {
|
||||
result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
|
||||
@@ -34,6 +34,7 @@ template<typename RasterOrderOptions>
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool verify = true;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
|
||||
@@ -41,6 +42,7 @@ struct Options {
|
||||
bool save_aux = true;
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int warmup = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
@@ -68,7 +70,9 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("device_scale", device_scale, false);
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("warmup", warmup);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("verify", verify);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
@@ -89,8 +93,8 @@ struct Options {
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
|
||||
out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n"
|
||||
<< " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
@@ -113,7 +117,7 @@ struct Options {
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
<< "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,841 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Groupwise-scaled FP8 Grouped GEMM example using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture
This example demonstrates a groupwise-scaled FP8 Grouped GEMM using the new CUTLASS 3.0
APIs on the NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
|
||||
descriptors to move between groups/problem_count (represented by groups).
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
Examples:
|
||||
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling \
|
||||
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
|
||||
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt file may look as follows:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <cfloat>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
#include "reference/host/gemm_with_groupwise_scaling.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
|
||||
using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
|
||||
|
||||
// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor
|
||||
// Given TileShape = Shape<_128,_128,_128>:
|
||||
// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor)
|
||||
// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling
|
||||
// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling
|
||||
// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling
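// Worked example: with size<0>(TileShape{}) == 128 and ScaleGranularityM == 1, each CTA tile
// carries 128 / 1 == 128 scale values along M (ScaleMsPerTile == 128); with ScaleGranularityM == 128
// it carries a single scale value (ScaleMsPerTile == 1). The same arithmetic applies along N.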
template <int ScaleGranularityM_, int ScaleGranularityN_>
|
||||
struct GroupScaleConfig {
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_;
|
||||
static constexpr int ScaleGranularityN = ScaleGranularityN_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
|
||||
"FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile,
|
||||
"FP8 scaling granularity must evenly divide tile shape along N.");
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
|
||||
};
|
||||
|
||||
using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>;
|
||||
using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>;
|
||||
using GroupScale2D1DConfig = GroupScaleConfig<size<0>(TileShape_{}), 1>;
|
||||
using GroupScale2D2DConfig = GroupScaleConfig<size<0>(TileShape_{}), size<1>(TileShape_{})>;
|
||||
|
||||
template <typename ScheduleConfig>
|
||||
struct GroupScaleGemm {
|
||||
using ArchTag = typename ScheduleConfig::ArchTag;
|
||||
using OperatorClass = typename ScheduleConfig::OperatorClass;
|
||||
using TileShape = typename ScheduleConfig::TileShape;
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape;
|
||||
using KernelSchedule = typename ScheduleConfig::KernelSchedule;
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
|
||||
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
|
||||
using FusionOperation = typename ScheduleConfig::FusionOperation;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutD *, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopWithGroupWiseScaling,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
using GroupScale1D1DGemm = GroupScaleGemm<GroupScale1D1DConfig>;
|
||||
using GroupScale1D2DGemm = GroupScaleGemm<GroupScale1D2DConfig>;
|
||||
using GroupScale2D1DGemm = GroupScaleGemm<GroupScale2D1DConfig>;
|
||||
using GroupScale2D2DGemm = GroupScaleGemm<GroupScale2D2DConfig>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename GroupScale1D1DGemm::Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
/// Initialization
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_blockscale_A;
|
||||
std::vector<int64_t> offset_blockscale_B;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<ElementD> block_D;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
|
||||
|
||||
cutlass::DeviceAllocation<const ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_D;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023,
|
||||
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
|
||||
|
||||
double _scope_max, _scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
if (bits_input == 1) {
|
||||
_scope_max = 2;
|
||||
_scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
_scope_max = 2;
|
||||
_scope_min = -2;
|
||||
} else if (bits_input == 16) {
|
||||
_scope_max = 5;
|
||||
_scope_min = -5;
|
||||
} else {
|
||||
_scope_max = 8;
|
||||
_scope_min = -8;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
|
||||
_scope_max = scope_max;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
|
||||
_scope_min = scope_min;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
template <typename OptionType>
|
||||
void allocate(const OptionType &options) {
|
||||
|
||||
using TileShape = typename OptionType::GroupScaleConfig::TileShape;
|
||||
const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile;
|
||||
const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile;
|
||||
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_blockscale_A = 0;
|
||||
int64_t total_elements_blockscale_B = 0;
|
||||
|
||||
offset_A.clear();
|
||||
offset_B.clear();
|
||||
offset_C.clear();
|
||||
offset_D.clear();
|
||||
offset_blockscale_A.clear();
|
||||
offset_blockscale_B.clear();
|
||||
stride_A_host.clear();
|
||||
stride_B_host.clear();
|
||||
stride_C_host.clear();
|
||||
stride_D_host.clear();
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{})));
|
||||
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in the scale tensor of A to prevent illegal memory access.
auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in the scale tensor of B to prevent illegal memory access.
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_blockscale_A.push_back(total_elements_blockscale_A);
|
||||
offset_blockscale_B.push_back(total_elements_blockscale_B);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_blockscale_A = groupscale_m * blockscale_k;
|
||||
int64_t elements_blockscale_B = groupscale_n * blockscale_k;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_blockscale_A += elements_blockscale_A;
|
||||
total_elements_blockscale_B += elements_blockscale_B;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
blockscale_block_A.reset(total_elements_blockscale_A);
|
||||
blockscale_block_B.reset(total_elements_blockscale_B);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
template <typename OptionType>
|
||||
void initialize(const OptionType &options) {
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementD *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
|
||||
|
||||
alpha_host.clear();
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_blockscale_A.reset(options.groups);
|
||||
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
|
||||
|
||||
ptr_blockscale_B.reset(options.groups);
|
||||
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2022);
|
||||
initialize_block(block_B, seed + 2023);
|
||||
initialize_block(block_C, seed + 2024);
|
||||
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
|
||||
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template<typename GemmArguments, typename OptionType>
|
||||
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
int device_id = 0;
|
||||
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename GroupScale1D1DGemm::Gemm::GemmKernel>(device_id);
|
||||
|
||||
GemmArguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_blockscale_A.get(),
|
||||
ptr_blockscale_B.get()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
ptr_C.get(), stride_C.get(),
|
||||
ptr_D.get(), stride_D.get()
|
||||
},
|
||||
kernel_hw_info
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha and beta are provided (via command-line args), the same scalar alpha/beta applies to all groups.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// Otherwise, use per-group alpha/beta pointers, i.e., alpha/beta can differ between groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will use a swizzle of up to 8, rounded to the nearest power of two (i.e., 1, 2, 4, or 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename OptionType>
|
||||
bool verify(const OptionType &options) {
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
std::vector<ElementB> block_B_host(block_B.size());
|
||||
std::vector<ElementC> block_C_host(block_C.size());
|
||||
std::vector<ElementD> block_D_host_kernel(block_D.size());
|
||||
std::vector<ElementD> block_D_host_ref(block_D.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
|
||||
|
||||
block_A.copy_to_host(block_A_host.data());
|
||||
block_B.copy_to_host(block_B_host.data());
|
||||
block_C.copy_to_host(block_C_host.data());
|
||||
block_D.copy_to_host(block_D_host_kernel.data());
|
||||
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensor shapes are based on `ScaleGranularityM`/`ScaleGranularityN`, the CTA tile shape (TileShape), and the GEMM problem shape
|
||||
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
auto groupscale_m = blockscale_m * OptionType::GroupScaleConfig::ScaleMsPerTile;
|
||||
auto groupscale_n = blockscale_n * OptionType::GroupScaleConfig::ScaleNsPerTile;
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, k, 1),
|
||||
stride_A_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(n, k, 1),
|
||||
stride_B_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_C_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_D_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
|
||||
auto blockscale_A = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(groupscale_m, blockscale_k, 1),
|
||||
cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k)
|
||||
)
|
||||
);
|
||||
auto blockscale_B = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(groupscale_n, blockscale_k, 1),
|
||||
cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k)
|
||||
)
|
||||
);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(B),
|
||||
decltype(blockscale_A),
|
||||
decltype(blockscale_B),
|
||||
TileShape_
|
||||
> mainloop_params{
|
||||
A, B, // Operand Tensors
|
||||
blockscale_A, blockscale_B // Groupwise scaling Tensors
|
||||
};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // Aux
|
||||
unused_t, // valpha
|
||||
unused_t, // vbeta
|
||||
ActivationFunctor
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = alpha_host.at(group_idx);
|
||||
epilogue_params.beta = beta_host.at(group_idx);
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
auto this_group_passed = std::equal(
|
||||
// std::execution::par_unseq,
|
||||
block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
|
||||
block_D_host_kernel.data() + offset_D.at(group_idx)
|
||||
);
|
||||
|
||||
passed &= this_group_passed;
|
||||
|
||||
#if 0
|
||||
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm, typename OptionType>
|
||||
int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
using TileShape = typename OptionType::GroupScaleConfig::TileShape;
|
||||
const int ScaleGranularityM = OptionType::GroupScaleConfig::ScaleGranularityM;
|
||||
const int ScaleGranularityN = OptionType::GroupScaleConfig::ScaleGranularityN;
|
||||
const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile;
|
||||
const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile;
|
||||
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with the CUDA 12.3 Toolkit (or newer) to run this example
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Returning zero so this test passes on older toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions, ProblemShape, GroupScale1D1DConfig> options_1d1d;
|
||||
Options<RasterOrderOptions, ProblemShape, GroupScale1D2DConfig> options_1d2d;
|
||||
Options<RasterOrderOptions, ProblemShape, GroupScale2D1DConfig> options_2d1d;
|
||||
Options<RasterOrderOptions, ProblemShape, GroupScale2D2DConfig> options_2d2d;
|
||||
|
||||
options_1d1d.parse(argc, args);
|
||||
options_1d2d.parse(argc, args);
|
||||
options_2d1d.parse(argc, args);
|
||||
options_2d2d.parse(argc, args);
|
||||
|
||||
if (options_1d1d.help) {
|
||||
options_1d1d.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
auto run_tests = [&] (bool host_problem_shapes_available = true) {
|
||||
std::cout << "Grouped GEMM kernel with 1D1D group scale" << std::endl;
|
||||
run<GroupScale1D1DGemm::Gemm>(options_1d1d, host_problem_shapes_available);
|
||||
std::cout << "Grouped GEMM kernel with 1D2D group scale" << std::endl;
|
||||
run<GroupScale1D2DGemm::Gemm>(options_1d2d, host_problem_shapes_available);
|
||||
std::cout << "Grouped GEMM kernel with 2D1D group scale" << std::endl;
|
||||
run<GroupScale2D1DGemm::Gemm>(options_2d1d, host_problem_shapes_available);
|
||||
std::cout << "Grouped GEMM kernel with 2D2D group scale" << std::endl;
|
||||
run<GroupScale2D2DGemm::Gemm>(options_2d2d, host_problem_shapes_available);
|
||||
std::cout << std::endl;
|
||||
};
|
||||
|
||||
std::cout << "Running tests with host problem shapes:" << std::endl;
|
||||
run_tests(true);
|
||||
std::cout << "Running tests without host problem shapes:" << std::endl;
|
||||
run_tests(false);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
|
||||
# Only the correctness check will be run by these commands.
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=512 --groups=50 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
|
||||
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
)
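# For reference, each TEST_* option set above corresponds to a direct invocation of the example
# binary; e.g. TEST_FIXED is equivalent to running (sketch only -- the binary path depends on the
# build tree layout):
#
#   ./68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling \
#     --m=2048 --n=5120 --k=512 --groups=50 --iterations=0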
@@ -0,0 +1,211 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
template<typename _RasterOrderOptions, typename _ProblemShape, typename _GroupScaleConfig>
|
||||
struct Options {
|
||||
|
||||
using RasterOrderOptions = _RasterOrderOptions;
|
||||
using ProblemShape = _ProblemShape;
|
||||
using GroupScaleConfig = _GroupScaleConfig;
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, groups = 10;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
|
||||
int const k_alignment = 128;
|
||||
int const m_alignment = 128;
|
||||
int const n_alignment = 128;
|
||||
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_host.reserve(groups);
|
||||
|
||||
for (int i = groups; i > 0; i--) {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
|
||||
}
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n"
|
||||
<< " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size.\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = 0ull;
|
||||
|
||||
for (auto const [m, n, k] : problem_sizes_host) {
|
||||
fmas += static_cast<uint64_t>(m) *
|
||||
static_cast<uint64_t>(n) *
|
||||
static_cast<uint64_t>(k);
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,520 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GETT in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/relatively_equal.h"
|
||||
#include <iostream>
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::reference::host {
|
||||
|
||||
template<class T, class = void>
|
||||
struct ElementTraits {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
|
||||
using type = decltype(std::declval<T>().get());
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAccumulator_,
|
||||
class TensorA_, // (M, K, L)
|
||||
class TensorB_, // (N, K, L)
|
||||
class TensorScaleA_, // (m, k, L)
|
||||
class TensorScaleB_, // (n, k, L)
|
||||
class TileShape_
|
||||
>
|
||||
struct GettMainloopParams {
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using TensorA = TensorA_;
|
||||
using TensorB = TensorB_;
|
||||
using EngineA = typename TensorA::engine_type;
|
||||
using LayoutA = typename TensorA::layout_type;
|
||||
using EngineB = typename TensorB::engine_type;
|
||||
using LayoutB = typename TensorB::layout_type;
|
||||
|
||||
using TensorScaleA = TensorScaleA_;
|
||||
using TensorScaleB = TensorScaleB_;
|
||||
using TileShape = TileShape_;
|
||||
using EngineScaleA = typename TensorScaleA::engine_type;
|
||||
using EngineScaleB = typename TensorScaleB::engine_type;
|
||||
|
||||
TensorA A{};
|
||||
TensorB B{};
|
||||
TensorScaleA ScaleA{};
|
||||
TensorScaleB ScaleB{};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementScalingFactor_,
|
||||
class ElementAccumulator_,
|
||||
class ElementCompute_,
|
||||
class TensorC_, // (M, N, L)
|
||||
class TensorD_, // (M, N, L)
|
||||
class VectorBias_ = TensorD_, // (M, 1)
|
||||
class TensorAux_ = TensorD_, // (M, N, L)
|
||||
class VectorAlpha_ = TensorD_, // (M, 1)
|
||||
class VectorBeta_ = VectorAlpha_, // (M, 1)
|
||||
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
|
||||
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
|
||||
bool PerColumnBias_ = false
|
||||
>
|
||||
struct GettEpilogueParams {
|
||||
using ElementScalar = ElementScalar_;
|
||||
using ElementScalingFactor = ElementScalingFactor_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using TensorC = TensorC_;
|
||||
using TensorD = TensorD_;
|
||||
using TensorAux = TensorAux_;
|
||||
using VectorBias = VectorBias_;
|
||||
using VectorAlpha = VectorAlpha_;
|
||||
using VectorBeta = VectorBeta_;
|
||||
using ActivationFunctor = ActivationFunctor_;
|
||||
using BiasBinaryOp = BiasBinaryOp_;
|
||||
|
||||
using EngineC = typename TensorC::engine_type;
|
||||
using LayoutC = typename TensorC::layout_type;
|
||||
using EngineD = typename TensorD::engine_type;
|
||||
using LayoutD = typename TensorD::layout_type;
|
||||
static constexpr bool PerColumnBias = PerColumnBias_;
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
|
||||
TensorC C{};
|
||||
TensorD D{};
|
||||
VectorBias Bias{};
|
||||
TensorAux Aux{};
|
||||
VectorAlpha Valpha{};
|
||||
VectorBeta Vbeta{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
|
||||
ElementAccumulator* abs_max_D = nullptr;
|
||||
ElementAccumulator* abs_max_Aux = nullptr;
|
||||
|
||||
ElementScalingFactor scale_a = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_b = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_c = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_d = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_aux = ElementScalingFactor(1);
|
||||
|
||||
bool beta_per_channel_scaling = false;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gett(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
|
||||
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
|
||||
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
|
||||
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
|
||||
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel for collapse(3)
|
||||
#endif
|
||||
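  // Each (m, n, l) iteration computes one kBlockM x kBlockN tile of the output: the mainloop
  // accumulates over K with blockwise scaling, and the epilogue applies alpha/beta and writes D.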
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
|
||||
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
|
||||
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
|
||||
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
|
||||
gett_mainloop(mainloop_params, m, n, l, acc);
|
||||
gett_epilogue(epilogue_params, m, n, l, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Mainloop
|
||||
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_mainloop(
|
||||
MainloopParams const& mainloop_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
||||
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
||||
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
|
||||
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
|
||||
|
||||
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
|
||||
RingOp fma_op;
|
||||
|
||||
multiplies<ElementAccumulator> scale_op;
|
||||
|
||||
  static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});
|
||||
|
||||
  // Temporary accumulators to separate blockwise accumulation
|
||||
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
|
||||
|
||||
// Zero out accumulators
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
|
||||
const int M = cute::size<0>(mainloop_params.A.layout());
|
||||
const int N = cute::size<0>(mainloop_params.B.layout());
|
||||
|
||||
const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA.layout());
|
||||
const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB.layout());
|
||||
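  // Granularity = number of output rows (for A) or columns (for B) that share a single scale factor, inferred from the scale tensor extents.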
|
||||
assert(ScaleGranularityM && M % ScaleGranularityM == 0 && "ScaleGranularityM must divide M");
|
||||
assert(ScaleGranularityN && N % ScaleGranularityN == 0 && "ScaleGranularityN must divide N");
|
||||
|
||||
cute::Tensor blockscale_A = domain_offset(make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l));
|
||||
cute::Tensor blockscale_B = domain_offset(make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l));
|
||||
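  // domain_offset shifts the scale tensors so that coordinate (0, k-block) refers to the scales covering this output tile.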
|
||||
// Compute on this k-block
|
||||
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
|
||||
|
||||
    // Load the blockwise scaling factors for A and B from the blockscale tensors
|
||||
int64_t block_k = k / kBlockK;
|
||||
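    // All K iterations inside the same kBlockK-wide block share one column of scale factors.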
cute::Tensor scale_a = blockscale_A(_, block_k);
|
||||
cute::Tensor scale_b = blockscale_B(_, block_k);
|
||||
|
||||
// Load A
|
||||
ElementAccumulator a_frag[kBlockM];
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
||||
} else {
|
||||
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// Load B
|
||||
ElementAccumulator b_frag[kBlockN];
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
|
||||
        // Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
|
||||
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
||||
} else {
|
||||
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// do compute
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Groupwise-scaling at kBlockK boundary
|
||||
    // (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlockK boundary
|
||||
// (b) Zero-out partial temporary (acc_temp),
|
||||
    // (c) Update the permanent accumulators (acc)
|
||||
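    // i.e., acc += scale_a(row group) * scale_b(col group) * sum_{k in block} A(m, k) * B(n, k)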
if ((k+1) % kBlockK == 0) {
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
|
||||
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
|
||||
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Epilogue
|
||||
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_epilogue(
|
||||
EpilogueParams const& epilogue_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
  static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, N, B");
|
||||
  static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "M, N, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementCompute = typename EpilogueParams::ElementCompute;
|
||||
using ElementC = typename EpilogueParams::TensorC::value_type;
|
||||
using ElementD = typename EpilogueParams::TensorD::value_type;
|
||||
using ElementAux = typename EpilogueParams::TensorAux::value_type;
|
||||
using ElementBias = typename EpilogueParams::VectorBias::value_type;
|
||||
using ElementScalar = typename EpilogueParams::ElementScalar;
|
||||
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
|
||||
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsReLUAuxNeeded =
|
||||
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
|
||||
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
|
||||
constexpr bool IsClamp =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
|
||||
|
||||
constexpr bool IsBackpropFusion =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
|
||||
|
||||
// Input related converter
|
||||
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
|
||||
NumericConverter<ElementCompute, ElementC> source_converter;
|
||||
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
|
||||
|
||||
// Scale related converter
|
||||
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
||||
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
|
||||
|
||||
// Abs max converter
|
||||
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
|
||||
|
||||
// Output related converter
|
||||
NumericConverter<ElementD, ElementCompute> destination_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
|
||||
NumericConverter<ElementBias, ElementCompute> dBias_converter;
|
||||
|
||||
// Epilogue operations
|
||||
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
|
||||
multiplies<ElementCompute> mul;
|
||||
plus<ElementCompute> add;
|
||||
|
||||
// Activation operation
|
||||
|
||||
auto activation = [] (ElementCompute x, ElementCompute y = ElementCompute(0)) {
|
||||
if constexpr (std::is_same_v<ActivationFunctor, void>) {
|
||||
return x + y;
|
||||
} else {
|
||||
return ActivationFunctor()(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
// Bias binary operation
|
||||
BiasBinaryOp bias_op;
|
||||
|
||||
// Do conversion
|
||||
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
|
||||
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
|
||||
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
|
||||
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
|
||||
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
|
||||
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
|
||||
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
|
||||
|
||||
// Init local var
|
||||
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
|
||||
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
|
||||
|
||||
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
|
||||
converted_beta = mul(converted_beta, converted_scale_c);
|
||||
|
||||
ElementCompute inter_accum[kBlockM][kBlockN];
|
||||
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
ElementCompute local_dBias = ElementCompute(0);
|
||||
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
|
||||
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
|
||||
// per-row alpha
|
||||
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
|
||||
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
|
||||
}
|
||||
ElementCompute output = mul(converted_alpha, converted_acc);
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
|
||||
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
|
||||
output = bias_op(output, converted_bias);
|
||||
}
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.C.data())) {
|
||||
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
|
||||
// per-row beta
|
||||
if (epilogue_params.Vbeta.data()) {
|
||||
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
|
||||
}
|
||||
output = epilogue_fma(converted_beta, converted_src, output);
|
||||
}
|
||||
|
||||
if constexpr (IsBackpropFusion) {
|
||||
ElementAux aux_input = ElementAux(0);
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
|
||||
}
|
||||
|
||||
output = activation(output, aux_source_converter(aux_input));
|
||||
local_dBias = add(local_dBias, output);
|
||||
}
|
||||
else {
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
auto aux_output = output;
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
|
||||
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
|
||||
}
|
||||
|
||||
if constexpr (IsReLUAuxNeeded) {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
|
||||
} else {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsClamp) { // Treat Clamp as ReLU
|
||||
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
|
||||
}
|
||||
else {
|
||||
output = activation(output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_output = amax_op(local_abs_max_output, output);
|
||||
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
|
||||
}
|
||||
|
||||
inter_accum[m_b][n_b] = ElementCompute(output);
|
||||
}
|
||||
} // n_b
|
||||
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
|
||||
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
|
||||
local_dBias = add(local_dBias, converted_dBias);
|
||||
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
|
||||
}
|
||||
}
|
||||
} // m_b
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp critical(Abs_Max_Data_Update)
|
||||
#endif
|
||||
{
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_D) {
|
||||
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_Aux) {
|
||||
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM - General Matrix-Matrix contraction without conjugation options
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gemm3x(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) "
|
||||
"with Batchmode are supported");
|
||||
// Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett).
|
||||
Gett(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // cutlass::reference::host
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,585 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
    \brief An FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB  = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC  = 128 / cutlass::sizeof_bits<ElementC>::value;    // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
|
||||
using ElementCompute = float;
|
||||
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA; it can span 2 SMs if the cluster shape's M mode is a multiple of 2
|
||||
using MmaTileShape_MNK = Shape<_128,_128,_128>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{}));
|
||||
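// Note (assumption): the "trivial" blockwise config uses one scale factor per MMA tile of A and of B,
// i.e. 128x128 blocks per 128-wide K block for the tile shape above.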
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutC, AlignmentD,
|
||||
cutlass::epilogue::TmaWarpSpecialized1Sm
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
// Strides just iterate over scalars and have no zeros
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
// Layouts are tiled to the problem size and the strides have zeros
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
|
||||
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
|
||||
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool skip_verification = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("skip-verification")) {
|
||||
skip_verification = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "81_blackwell_gemm_blockwise\n\n"
|
||||
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --skip-verification Skip verification.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "81_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
    uint64_t flop = uint64_t(2) * m * n * k * l;  // include the batch (L) dimension
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -8;
|
||||
scope_max = 8;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
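  // filter_zeros drops the broadcast (stride-0) modes of the scale-factor layouts, so these coords count only the distinct scale factors that need storage.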
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
tensor_SFA.resize(blockscale_a_coord);
|
||||
tensor_SFB.resize(blockscale_b_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
|
||||
|
||||
initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
|
||||
initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
|
||||
tensor_SFA.sync_device();
|
||||
tensor_SFB.sync_device();
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
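  // The mainloop arguments pass the scale-factor tensors (SFA/SFB) and their layouts alongside A and B;
  // alpha/beta are set on the epilogue fusion arguments below.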
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A,
|
||||
tensor_B.device_data(), stride_B,
|
||||
tensor_SFA.device_data(), layout_SFA,
|
||||
tensor_SFB.device_data(), layout_SFB},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
|
||||
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
Result result;
|
||||
if (!options.skip_verification) {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
  // CUTLASS must be compiled with the CUDA 12 Toolkit (or newer) to run this example,
|
||||
  // and the device must have compute capability of at least sm100a.
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
    // Returning zero so this test passes on older Toolkits; its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Run
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,589 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
    \brief An FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB  = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC  = 128 / cutlass::sizeof_bits<ElementC>::value;    // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
|
||||
using ElementCompute = float;
|
||||
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA; it can span 2 SMs if the cluster shape's M mode is a multiple of 2
|
||||
using MmaTileShape_MNK = Shape<_128,_128,_128>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 128;
|
||||
constexpr int ScaleGranularityK = 128;
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
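// Groupwise scaling: with these granularities each row of A gets its own scale per 128-wide K block,
// while B uses one scale per 128x128 block.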
|
||||
// Note: when there are multiple scale factors per tile (here, 128 scales along M per tile), the alignment
|
||||
// requirement on the scales is raised to 16B when possible (i.e., when there are at least 16B of scales along M).
|
||||
// In that case the smallest M that can be executed is 16. To avoid this restriction for smaller M, you can swap A and B
|
||||
// and transpose A, B, C, and the scales, since B^T A^T = C^T.
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutC, AlignmentD,
|
||||
cutlass::epilogue::TmaWarpSpecialized1Sm
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
// Strides just iterate over scalars and have no zeros
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
// Layouts are tiled to the problem size and the strides have zeros
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
|
||||
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
|
||||
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool skip_verification = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("skip-verification")) {
|
||||
skip_verification = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "81_blackwell_gemm_groupwise\n\n"
|
||||
<< " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --skip-verification Skip verification.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "81_blackwell_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const
  {
    // Two flops per multiply-add
    uint64_t flop = uint64_t(2) * m * n * k;
    double gflop = double(flop) / double(1.0e9);
    return gflop / runtime_s;
  }
};
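As a quick check of the FLOP accounting above (illustrative numbers only): with the default problem size m=1024, n=512, k=1024, flop = 2 * 1024 * 512 * 1024 ≈ 1.07 GFLOP, so an average runtime of 0.5 ms reports roughly 1.07 / 0.0005 ≈ 2150 GFLOP/s. Note that the count as written covers a single batch; the l extent is not folded into the product.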
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
    throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -8;
|
||||
scope_max = 8;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
    throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
tensor_SFA.resize(blockscale_a_coord);
|
||||
tensor_SFB.resize(blockscale_b_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
|
||||
|
||||
initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
|
||||
initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
|
||||
tensor_SFA.sync_device();
|
||||
tensor_SFB.sync_device();
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A,
|
||||
tensor_B.device_data(), stride_B,
|
||||
tensor_SFA.device_data(), layout_SFA,
|
||||
tensor_SFB.device_data(), layout_SFB},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
|
||||
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
Result result;
|
||||
if (!options.skip_verification) {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least sm100a.
|
||||
|
||||
  if (__CUDACC_VER_MAJOR__ < 12) {
    std::cerr << "This example requires CUDA 12 or newer.\n";
    // Returning zero so this test passes on older toolkits. Its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (props.major != 10 || props.minor != 0) {
    std::cerr << "This example requires a GPU with compute capability 100a." << std::endl;
    return 0;
  }
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Run
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
examples/81_blackwell_gemm_blockwise/CMakeLists.txt (new file, 57 lines)
@ -0,0 +1,57 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)

set(TEST_RANDOM --iterations=0) # Random problem sizes

set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes

set(TEST_SMALL --m=256 --n=128 --k=128 --iterations=0) # Small problem sizes

cutlass_example_add_executable(
  81_blackwell_gemm_blockwise
  81_blackwell_gemm_blockwise.cu
  TEST_COMMAND_OPTIONS
  TEST_RANDOM
  TEST_EPILOGUE
  TEST_SMALL
  )

cutlass_example_add_executable(
  81_blackwell_gemm_groupwise
  81_blackwell_gemm_groupwise.cu
  TEST_COMMAND_OPTIONS
  TEST_RANDOM
  TEST_EPILOGUE
  TEST_SMALL
  )

endif()
|
||||
@ -146,6 +146,7 @@ foreach(EXAMPLE
  64_ada_fp8_gemm_grouped
  65_distributed_gemm
  67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
  68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
  69_hopper_mixed_dtype_grouped_gemm
  70_blackwell_gemm
  71_blackwell_gemm_with_collective_builder
@ -156,6 +157,7 @@ foreach(EXAMPLE
  76_blackwell_conv
  77_blackwell_fmha
  78_blackwell_emulated_bf16x9_gemm
  81_blackwell_gemm_blockwise
  )

  add_subdirectory(${EXAMPLE})
|
||||
|
||||
include/cutlass/detail/sm100_blockwise_scale_layout.hpp (new file, 189 lines)
@ -0,0 +1,189 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cute/int_tuple.hpp"
|
||||
#include "cute/atom/mma_traits_sm100.hpp"
|
||||
|
||||
namespace cutlass::detail{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
using namespace cute;
|
||||
|
||||
template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
|
||||
struct Sm100BlockwiseScaleConfig {
|
||||
|
||||
using ShapeSFA = Shape<Shape<Int<SFVecSizeM>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
|
||||
using ShapeSFB = Shape<Shape<Int<SFVecSizeN>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
|
||||
|
||||
using StrideSFA = conditional_t<majorSFA == UMMA::Major::MN,
|
||||
Stride<Stride<_0,_1>,Stride<_0,int32_t>, int32_t>,
|
||||
Stride<Stride<_0,int32_t>,Stride<_0,_1>, int32_t>>;
|
||||
|
||||
using StrideSFB = conditional_t<majorSFB == UMMA::Major::MN,
|
||||
Stride<Stride<_0,_1>,Stride<_0,int32_t>, int32_t>,
|
||||
Stride<Stride<_0,int32_t>,Stride<_0,_1>, int32_t>>;
|
||||
|
||||
using LayoutSFA = Layout<ShapeSFA, StrideSFA>;
|
||||
using LayoutSFB = Layout<ShapeSFB, StrideSFB>;
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
static constexpr auto
|
||||
deduce_layoutSFA() {
|
||||
return LayoutSFA{};
|
||||
}
|
||||
|
||||
template<typename CtaShape_MNK>
|
||||
CUTE_HOST_DEVICE
|
||||
static constexpr auto
|
||||
smem_atom_layoutSFA(CtaShape_MNK cta_shape_mnk) {
|
||||
static_assert(cute::is_static_v<CtaShape_MNK>, "Expect static CTA shape");
|
||||
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
|
||||
auto [M, N, K] = cta_shape_mnk;
|
||||
if constexpr (majorSFA == UMMA::Major::MN) {
|
||||
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int<cute::ceil_div(size<0>(CtaShape_MNK{}), SFVecSizeM)>{}));
|
||||
}
|
||||
else {
|
||||
return make_stride(make_stride(_0{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto [M, N, K] = cta_shape_mnk;
|
||||
return make_layout(
|
||||
make_shape(make_shape(Int<SFVecSizeM>{}, Int<cute::ceil_div(size<0>(CtaShape_MNK{}), SFVecSizeM)>{}),
|
||||
make_shape(Int<SFVecSizeK>{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{})),
|
||||
strides
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
static constexpr auto
|
||||
deduce_layoutSFB() {
|
||||
return LayoutSFB{};
|
||||
}
|
||||
|
||||
template<typename CtaShape_MNK>
|
||||
CUTE_HOST_DEVICE
|
||||
static constexpr auto
|
||||
smem_atom_layoutSFB(CtaShape_MNK cta_shape_mnk) {
|
||||
static_assert(cute::is_static_v<CtaShape_MNK>, "Expect static CTA shape");
|
||||
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
|
||||
if constexpr (majorSFA == UMMA::Major::MN) {
|
||||
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int<cute::ceil_div(size<1>(CtaShape_MNK{}), SFVecSizeN)>{}));
|
||||
}
|
||||
else {
|
||||
return make_stride(make_stride(_0{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto [M, N, K] = cta_shape_mnk;
|
||||
return make_layout(
|
||||
make_shape(make_shape(Int<SFVecSizeN>{}, Int<cute::ceil_div(size<1>(CtaShape_MNK{}), SFVecSizeN)>{}),
|
||||
make_shape(Int<SFVecSizeK>{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{})),
|
||||
strides
|
||||
);
|
||||
}
|
||||
|
||||
  // The following function is provided for users to fill the dynamic problem size into layout_SFA.
|
||||
template <class ProblemShape>
|
||||
CUTE_HOST_DEVICE
|
||||
static constexpr auto
|
||||
tile_atom_to_shape_SFA(ProblemShape problem_shape) {
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
if constexpr (majorSFA == UMMA::Major::MN) {
|
||||
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, SFVecSizeM)));
|
||||
}
|
||||
else {
|
||||
return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
auto mk_layout = make_layout(
|
||||
make_shape(make_shape(Int<SFVecSizeM>{}, cute::ceil_div(M, SFVecSizeM)),
|
||||
make_shape(Int<SFVecSizeK>{}, cute::ceil_div(K, SFVecSizeK))),
|
||||
strides
|
||||
);
|
||||
|
||||
return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout))));
|
||||
}
|
||||
|
||||
  // The following function is provided for users to fill the dynamic problem size into layout_SFB.
|
||||
template <class ProblemShape>
|
||||
CUTE_HOST_DEVICE
|
||||
static constexpr auto
|
||||
tile_atom_to_shape_SFB(ProblemShape problem_shape) {
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
if constexpr (majorSFB == UMMA::Major::MN) {
|
||||
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, SFVecSizeN)));
|
||||
}
|
||||
else {
|
||||
return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
auto nk_layout = make_layout(
|
||||
make_shape(make_shape(Int<SFVecSizeN>{}, cute::ceil_div(N, SFVecSizeN)),
|
||||
make_shape(Int<SFVecSizeK>{}, cute::ceil_div(K, SFVecSizeK))),
|
||||
strides
|
||||
);
|
||||
|
||||
return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout))));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template<class MmaTileShape_MNK>
|
||||
constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) {
|
||||
return Sm100BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{};
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::detail
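To make the intended use of deduce_layoutSFA/SFB and tile_atom_to_shape_SFA/SFB concrete, here is a minimal host-side sketch (a hedged example, not part of the commit: the (1, 128, 128) vector sizes and the M/N/K/L values are assumed, mirroring how the 81_blackwell examples size and fill their scale-factor layouts):

// Minimal sketch: build scale-factor layouts for a groupwise-scaled GEMM.
// Vector sizes (1, 128, 128) are an assumed example configuration.
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128>;

using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // static shape/stride skeleton
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

int M = 1024, N = 512, K = 1024, L = 1;  // runtime problem size (assumed values)
LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));

// The number of scale-factor elements to allocate is the size of the layout with
// zero-stride modes filtered out, as done by the example's initialize():
auto num_sfa = cute::size(cute::filter_zeros(layout_SFA));
auto num_sfb = cute::size(cute::filter_zeros(layout_SFB));

The filter_zeros step matters because these layouts deliberately carry zero strides inside each scale-factor vector, so the allocation size is the number of distinct scale factors rather than the full logical extent.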
|
||||
@ -0,0 +1,304 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm100_common.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<
|
||||
int CapacityBytes,
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class ElementScalar,
|
||||
class ScaleShapeMNK,
|
||||
class TileShapeMNK,
|
||||
class MainloopPipelineStorage,
|
||||
class TransformLoadPipelineStorage,
|
||||
class TransformPipelineStorage,
|
||||
int stages
|
||||
>
|
||||
constexpr int
|
||||
sm100_compute_stage_count_or_override_blockwise(StageCount<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<
|
||||
int CapacityBytes,
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class ElementScalar,
|
||||
class ScaleShapeMNK,
|
||||
class TileShapeMNK,
|
||||
class MainloopPipelineStorage,
|
||||
class TransformLoadPipelineStorage,
|
||||
class TransformPipelineStorage,
|
||||
int stages
|
||||
>
|
||||
constexpr int
|
||||
sm100_compute_stage_count_or_override_blockwise(cute::Int<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template<
|
||||
int CapacityBytes,
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class ElementScalar,
|
||||
class ScaleShapeMNK,
|
||||
class TileShapeMNK,
|
||||
class MainloopPipelineStorage,
|
||||
class TransformLoadPipelineStorage,
|
||||
class TransformPipelineStorage,
|
||||
int carveout_bytes>
|
||||
constexpr int
|
||||
sm100_compute_stage_count_or_override_blockwise(StageCountAutoCarveout<carveout_bytes> stage_count) {
|
||||
// For F8/F6/F4 sub-bytes, ElementA/B will be passed in as uint8_t
|
||||
// For Planar Complex, ElementA/B will be passed in as cutlass::complex<ElementARaw>
|
||||
// Each stage include (CollectiveMma::SharedStorage)
|
||||
// 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage)
|
||||
// 2. one of each of the pipelines
|
||||
constexpr auto pipeline_bytes = sizeof(MainloopPipelineStorage) +
|
||||
sizeof(TransformLoadPipelineStorage) + sizeof(TransformPipelineStorage);
|
||||
|
||||
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
|
||||
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
|
||||
constexpr auto scale_bits = cute::sizeof_bits_v<ElementScalar>;
|
||||
|
||||
constexpr int stage_bytes =
|
||||
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
|
||||
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
|
||||
cutlass::bits_to_bytes(scale_bits * size<0>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
|
||||
cutlass::bits_to_bytes(scale_bits * size<1>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
|
||||
static_cast<int>(pipeline_bytes);
|
||||
|
||||
return (CapacityBytes - carveout_bytes) / stage_bytes;
|
||||
}
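// Worked example of the carveout arithmetic above (illustrative, assumed configuration):
// for fp8 A/B with a 128x128x128 SMEM tile, fp32 scales, and a (128, 1, 1) scale tile
// (128 scales along M for A and a single scale for B per K step), each stage costs roughly
// 16 KiB (A) + 16 KiB (B) + 512 B + 4 B (scales) + a few bytes of pipeline storage,
// so the returned stage count is about (CapacityBytes - carveout_bytes) / ~33 KiB.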
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATagPair,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTagPair,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class KernelScheduleType
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATagPair,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTagPair,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
|
||||
ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1)
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
not cute::is_tuple_v<ElementA> && not cute::is_tuple_v<ElementB> &&
|
||||
not cute::is_complex_v<ElementA> && not cute::is_complex_v<ElementB> &&
|
||||
cute::is_tuple_v<GmemLayoutATagPair> && cute::is_tuple_v<GmemLayoutBTagPair> &&
|
||||
// Dense Gemm
|
||||
cute::is_base_of_v<KernelScheduleSm100Blockwise, KernelScheduleType> &&
|
||||
// Alignment check
|
||||
detail::sm1xx_gemm_is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, KernelScheduleType>()>>
|
||||
{
|
||||
static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
|
||||
static_assert(detail::check_input_datatypes<ElementA, ElementB>(), "Incorrect input types");
|
||||
|
||||
using GmemLayoutATag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutATagPair{}))>;
|
||||
using GmemLayoutSFATag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutATagPair{}))>;
|
||||
using GmemLayoutBTag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutBTagPair{}))>;
|
||||
using GmemLayoutSFBTag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutBTagPair{}))>;
|
||||
|
||||
static_assert(cute::depth(GmemLayoutSFATag{}) == 2 and cute::depth(GmemLayoutSFBTag{}) == 2,
|
||||
"Expect SFA and SFB layout to be depth of two with shape ((SFVecMN, restMN),(SFVecK, restK), L)");
|
||||
static_assert(size<1,0>(GmemLayoutSFATag{}) == size<1, 0>(GmemLayoutSFBTag{}),
|
||||
"SFA and SFB must have equivalent SF vector sizes along K");
|
||||
|
||||
static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
|
||||
static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
|
||||
|
||||
// Data type used by MMA instruction
|
||||
using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element<ElementA>());
|
||||
using ElementBMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element<ElementB>());
|
||||
|
||||
static constexpr bool is_2sm = cute::is_base_of_v<KernelSchedule2Sm, KernelScheduleType> ||
|
||||
(not cute::is_base_of_v<KernelSchedule1Sm, KernelScheduleType> &&
|
||||
not cute::is_base_of_v<KernelSchedule2Sm, KernelScheduleType> &&
|
||||
cute::is_static_v<ClusterShape_MNK> &&
|
||||
cute::get<0>(ClusterShape_MNK{}) % 2 == 0 );
|
||||
|
||||
static_assert(detail::sm100_gemm_check_for_f8f6f4_mix8bit_requirement<ElementAMma, ElementBMma,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
UmmaMajorA, UmmaMajorB, KernelScheduleType, is_2sm>(),
|
||||
"TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" );
|
||||
using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma<
|
||||
ElementAMma, ElementBMma, ElementAccumulator,
|
||||
decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK,
|
||||
UmmaMajorA, UmmaMajorB, KernelScheduleType>());
|
||||
|
||||
using ElementAMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
|
||||
using ElementBMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
|
||||
|
||||
using AtomThrID = typename TiledMma::AtomThrID;
|
||||
|
||||
using AtomThrShapeMNK = cute::Shape<decltype(cute::shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
|
||||
using CtaTileShape_MNK = decltype(cute::shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));
|
||||
|
||||
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
|
||||
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
|
||||
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
|
||||
using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
|
||||
using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));
|
||||
using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
|
||||
using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
|
||||
|
||||
static_assert(BlockTileA_K{} == BlockTileB_K{}, "Block tile Ks should be equal");
|
||||
|
||||
using SmemShape_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{}))));
|
||||
using SmemShape_N = decltype(shape_div(shape<1>(TileShape_MNK{}), shape_div(shape<1>(TileShape_MNK{}), size<1>(TileShape_MNK{}) / size(AtomThrID{}))));
|
||||
using SmemShape_K = decltype(cute::get<2>(TileShape_MNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
|
||||
ClusterShape_MNK{}, AtomThrID{}));
|
||||
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B(
|
||||
ClusterShape_MNK{}, AtomThrID{}));
|
||||
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorA, ElementAMma_SmemAllocType, SmemShape_M, SmemShape_K>());
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorB, ElementBMma_SmemAllocType, SmemShape_N, SmemShape_K>());
|
||||
static constexpr uint32_t TotalTmemRows = 128;
|
||||
static constexpr uint32_t Sm100TmemCapacityColumns = 512;
|
||||
static constexpr uint32_t TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns;
|
||||
static constexpr uint32_t AccumulatorPipelineStageCount = (is_2sm || (!is_2sm && size(shape<0,0>(MmaShapeA_MK{})) > 64)) ?
|
||||
TotalTmem / (cute::size<0>(CtaTileShape_MNK{}) * cute::size<1>(CtaTileShape_MNK{}))
|
||||
: (Sm100TmemCapacityColumns / cute::size<1>(CtaTileShape_MNK{})) * 2; // 1SM MMA_M = 64 case
|
||||
static_assert(AccumulatorPipelineStageCount > 0, "Accumulator pipeline stage count must be positive. This error probably means that TileShape_MNK and/or TiledMma::ThrLayoutVMNK are wrong.");
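  // Illustrative arithmetic for the stage count above (assumed 1SM CTA tile of 128x128):
  //   TotalTmem / (CtaTileM * CtaTileN) = (128 * 512) / (128 * 128) = 4 accumulator stages,
  // while the 1SM MMA_M = 64 branch gives (Sm100TmemCapacityColumns / CtaTileN) * 2 stages instead.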
|
||||
|
||||
// Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding.
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
|
||||
using InternalStrideA = cute::remove_pointer_t<StrideA>;
|
||||
// Grouped GEMM (where Stride type is Stride*) does not use CLC based scheduler.
|
||||
// SchedulerPipelineStageCount could be set to zero for Grouped GEMM, but we shouldn't define CLC Pipeline's barrier arrays of size zero.
|
||||
static constexpr uint32_t SchedulerPipelineStageCount = cute::is_same_v<InternalStrideA, StrideA> ? (AccumulatorPipelineStageCount + 1) : 1;
|
||||
|
||||
static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout<
|
||||
ClusterShape_MNK,
|
||||
AccumulatorPipelineStageCount,
|
||||
SchedulerPipelineStageCount,
|
||||
detail::CLCResponseSize,
|
||||
false
|
||||
>::KernelSmemCarveout;
|
||||
// Reduce SMEM capacity available for buffers considering barrier allocations.
|
||||
static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
|
||||
|
||||
using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>;
|
||||
using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage;
|
||||
using TransformLoadPipelineStorage = typename cutlass::PipelineAsync<1>::SharedStorage;
|
||||
using TransformPipelineStorage = typename cutlass::PipelineUmmaAsync<1>::SharedStorage;
|
||||
|
||||
static constexpr int ScaleGranularityM = size<0,0>(GmemLayoutSFATag{});
|
||||
static constexpr int ScaleGranularityN = size<0,0>(GmemLayoutSFBTag{});
|
||||
static constexpr int ScaleGranularityK = size<1,0>(GmemLayoutSFBTag{});
|
||||
|
||||
static_assert(size<0>(CtaTileShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape");
|
||||
static_assert(size<1>(CtaTileShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape");
|
||||
static_assert(size<2>(CtaTileShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape");
|
||||
|
||||
using BlockTileScale_M = Int<size<0>(TileShape_MNK{}) / ScaleGranularityM>;
|
||||
using BlockTileScale_N = Int<size<1>(TileShape_MNK{}) / ScaleGranularityN>;
|
||||
using BlockTileScale_K = Int<size<2>(TileShape_MNK{}) / ScaleGranularityK>;
|
||||
|
||||
using ScaleTileShape = cute::Shape<BlockTileScale_M, BlockTileScale_N, BlockTileScale_K>;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockwise<
|
||||
Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType,
|
||||
ElementAccumulator, ScaleTileShape, SmemTileShape, MainloopPipelineStorage,
|
||||
TransformLoadPipelineStorage, TransformPipelineStorage>(StageCountType{});
|
||||
static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, and scales.");
|
||||
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling<
|
||||
PipelineStages,
|
||||
SchedulerPipelineStageCount,
|
||||
AccumulatorPipelineStageCount,
|
||||
ClusterShape_MNK>;
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
cute::tuple<cutlass::gemm::TagToStrideA_t<GmemLayoutATag>, cutlass::gemm::TagToStrideA_t<GmemLayoutSFATag>>,
|
||||
ElementB,
|
||||
cute::tuple<cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, cutlass::gemm::TagToStrideB_t<GmemLayoutSFBTag>>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
void,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
void,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1046,8 +1046,7 @@ template <
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
int ScaleGranularityM_,
|
||||
int ScaleGranularityN_
|
||||
class KernelScheduleType
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
@ -1062,11 +1061,16 @@ struct CollectiveBuilder<
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
||||
cute::is_same_v<decltype(KernelScheduleType::ScaleGranularityM), decltype(KernelScheduleType::ScaleGranularityN)> and
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()
|
||||
>
|
||||
> {
|
||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
|
||||
|
||||
static constexpr auto ScaleGranularityM_ = KernelScheduleType::ScaleGranularityM;
|
||||
static constexpr auto ScaleGranularityN_ = KernelScheduleType::ScaleGranularityN;
|
||||
static constexpr auto ScalePromotionInterval_ = KernelScheduleType::ScalePromotionInterval;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
@ -1076,12 +1080,12 @@ struct CollectiveBuilder<
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>);
|
||||
static constexpr bool IsArrayOfPointersGemm = (
|
||||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelScheduleType> ||
|
||||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelScheduleType>);
|
||||
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
|
||||
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
|
||||
static_assert(IsFP8Input, "Warp Specialized gemm with FP8 BlockScaled Accumulator is only compatible with FP8 Blocked Scaled version right now.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
@ -1091,10 +1095,9 @@ struct CollectiveBuilder<
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
|
||||
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>>;
|
||||
static constexpr bool IsCooperative = cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, KernelScheduleType> ||
|
||||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelScheduleType>;
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
@ -1121,7 +1124,9 @@ struct CollectiveBuilder<
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
||||
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile, ScaleNsPerTile>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_>;
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>,
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
@ -43,6 +43,7 @@
|
||||
#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl"
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -50,6 +50,7 @@
|
||||
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
|
||||
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp"
|
||||
@ -59,5 +60,6 @@
|
||||
|
||||
#include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -223,6 +223,30 @@ public:
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// scale (multiply_add) the results from the MMA accumulators to main accumulator without checking the counter.
|
||||
CUTLASS_DEVICE
|
||||
void scale(ElementAccumulator const &scale) {
|
||||
scale_core(scale);
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
scale_core(scale);
|
||||
}
|
||||
|
||||
template <
|
||||
class EngineScaleA,
|
||||
class LayoutScaleA,
|
||||
class EngineScaleB,
|
||||
class LayoutScaleB>
|
||||
CUTLASS_DEVICE
|
||||
void scale(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
|
||||
scale_core(scaleA, scaleB);
|
||||
}
|
||||
|
||||
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
|
||||
File diff suppressed because it is too large
@ -204,6 +204,8 @@ public:
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static constexpr int NumProducerThreadEvents = 1;
|
||||
|
||||
using SmemLayoutAtomScale = Layout<Shape<decltype(cute::shape<0>(SwappedSmemLayoutAtomA{})), cute::Int<1>>>;
|
||||
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{})));
|
||||
|
||||
@ -1354,6 +1356,18 @@ public:
|
||||
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in tensormaps_fence_acquire.");
|
||||
}
|
||||
}
|
||||
|
||||
template <class InputTensors, class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE
|
||||
InputTensors
|
||||
tensors_perform_update(
|
||||
InputTensors const& input_tensors,
|
||||
[[maybe_unused]] Params const& mainloop_params,
|
||||
[[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl,
|
||||
[[maybe_unused]] int32_t next_batch) {
|
||||
return input_tensors;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -116,6 +116,9 @@ struct CollectiveMma<
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
|
||||
static constexpr int NumProducerThreadEvents = 1;
|
||||
|
||||
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
@ -749,6 +752,16 @@ struct CollectiveMma<
|
||||
cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
|
||||
}
|
||||
|
||||
template <class InputTensors, class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE
|
||||
InputTensors
|
||||
tensors_perform_update(
|
||||
InputTensors const& input_tensors,
|
||||
[[maybe_unused]] Params const& mainloop_params,
|
||||
[[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl,
|
||||
[[maybe_unused]] int32_t next_batch) {
|
||||
return input_tensors;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
@ -759,6 +759,18 @@ struct CollectiveMma<
|
||||
cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
|
||||
cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
|
||||
}
|
||||
|
||||
template <class InputTensors, class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE
|
||||
InputTensors
|
||||
tensors_perform_update(
|
||||
InputTensors const& input_tensors,
|
||||
[[maybe_unused]] Params const& mainloop_params,
|
||||
[[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl,
|
||||
[[maybe_unused]] int32_t next_batch) {
|
||||
return input_tensors;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
File diff suppressed because it is too large
@ -59,6 +59,7 @@ template <
|
||||
class KernelSchedule,
|
||||
int ScaleGranularityM_,
|
||||
int ScaleGranularityN_,
|
||||
int ScalePromotionInterval_,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
@ -74,7 +75,7 @@ template <
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>,
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
@ -93,7 +94,7 @@ struct CollectiveMma<
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>;
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
@ -122,6 +123,8 @@ struct CollectiveMma<
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_;
|
||||
static constexpr int ScalePromotionInterval = ScalePromotionInterval_;
|
||||
static_assert(ScalePromotionInterval % 4 == 0, "ScalePromotionInterval must be a multiple of 4.");
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
@ -281,7 +284,9 @@ struct CollectiveMma<
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
|
||||
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
|
||||
constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{});
|
||||
implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval);
|
||||
implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval);
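      // Illustrative example (assumed shapes): with TileShape K = 128 and an MMA atom K of 32,
      // pipe_k = 4, so any mma_promotion_interval that is a multiple of 4, at least pipe_k,
      // and equal to the schedule's ScalePromotionInterval passes these checks.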
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
@ -481,6 +486,38 @@ struct CollectiveMma<
|
||||
}
|
||||
}
|
||||
|
||||
template<
|
||||
class EngineAccum,
|
||||
class LayoutAccum,
|
||||
class ScaleFactor
|
||||
>
|
||||
CUTLASS_DEVICE
|
||||
void scale_if_needed(GmmaFP8Accumulation<EngineAccum, LayoutAccum>& accumulation, ScaleFactor scaleFactor) {
|
||||
if constexpr (ScalePromotionInterval != 4) {
|
||||
accumulation.scale_if_needed(scaleFactor);
|
||||
}
|
||||
else {
|
||||
      // avoid unnecessary tests when granularity is the finest
|
||||
accumulation.scale(scaleFactor);
|
||||
}
|
||||
}
|
||||
template<
|
||||
class EngineAccum,
|
||||
class LayoutAccum,
|
||||
class ScaleFactor1,
|
||||
class ScaleFactor2
|
||||
>
|
||||
CUTLASS_DEVICE
|
||||
void scale_if_needed(GmmaFP8Accumulation<EngineAccum, LayoutAccum>& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) {
|
||||
if constexpr (ScalePromotionInterval != 4) {
|
||||
accumulation.scale_if_needed(scaleFactor1, scaleFactor2);
|
||||
}
|
||||
else {
|
||||
      // avoid unnecessary tests when granularity is the finest
|
||||
accumulation.scale(scaleFactor1, scaleFactor2);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <

@ -575,7 +612,7 @@ struct CollectiveMma<

tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA));
warpgroup_fence_operand(accumulation());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)

@ -584,7 +621,13 @@ struct CollectiveMma<
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);

if (accumulation.prepare_if_needed()) {
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
}
else {
// Always zero out the accumulator for finest granularity
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}

@ -624,16 +667,16 @@ struct CollectiveMma<
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_if_needed(scale_ab);
scale_if_needed(accumulation, scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC);
}

++smem_pipe_read;

@ -677,7 +720,13 @@ struct CollectiveMma<
}
}

if (accumulation.prepare_if_needed()) {
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
}
else {
// Always zero out the accumulator for finest granularity
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}

@ -699,16 +748,16 @@ struct CollectiveMma<
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_if_needed(scale_ab);
scale_if_needed(accumulation, scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC);
}

pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
@ -718,18 +767,21 @@ struct CollectiveMma<
++smem_pipe_release;
}

if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_residue_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
if constexpr (ScalePromotionInterval != 4) {
// residues only exist when granularity is not the finest
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_residue_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
}
}

warpgroup_fence_operand(accumulation());
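The new constexpr guard reflects that residue scaling can only arise when the promotion interval exceeds the four MMAs issued per k-tile iteration. A hedged arithmetic example with hypothetical values (not from this diff):

// With 4 MMAs per k-tile and ScalePromotionInterval = 16, promotion happens every 4 k-tiles.
// A problem with 10 k-tiles issues 40 MMAs, so 40 % 16 = 8 MMAs remain unpromoted when the
// mainloop exits and must go through scale_residue_if_needed(). With an interval of 4 the
// remainder is always 0, so the whole residue block compiles away.
constexpr int mmas_per_ktile = 4, promotion_interval = 16, k_tiles = 10;
constexpr int residue_mmas = (mmas_per_ktile * k_tiles) % promotion_interval;
static_assert(residue_mmas == 8, "MMAs left for the residue scaling path");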
@ -120,10 +120,38 @@ template<
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0
int ScaleGranularityM_ = 0,
int ScaleGranularityN_ = 0,
// `ScalePromotionInterval` specifies the interval at which the accumulator is promoted for scaling.
// It is required to be a multiple of 4 and is specified in terms of the number of MMA instructions
// in the reduction dimension, i.e., for FP8 kernels the default of 4 covers
// ScalePromotionInterval * MMA_K = 4 * 32 = 128 elements in K.
int ScalePromotionInterval_ = 4
>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative {
constexpr static int ScaleGranularityM = ScaleGranularityM_;
constexpr static int ScaleGranularityN = ScaleGranularityN_;
constexpr static int ScalePromotionInterval = ScalePromotionInterval_;
};

template<
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM_,
int ScaleGranularityN_,
// `ScalePromotionInterval` specifies the interval at which the accumulator is promoted for scaling.
// It is required to be a multiple of 4 and is specified in terms of the number of MMA instructions
// in the reduction dimension, i.e., for FP8 kernels the default of 4 covers
// ScalePromotionInterval * MMA_K = 4 * 32 = 128 elements in K.
int ScalePromotionInterval_ = 4
>
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative {
constexpr static int ScaleGranularityM = ScaleGranularityM_;
constexpr static int ScaleGranularityN = ScaleGranularityN_;
constexpr static int ScalePromotionInterval = ScalePromotionInterval_;
};
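A hedged usage sketch of the extended schedule (the granularity and interval values below are illustrative assumptions; real code typically reaches these policies through the CollectiveBuilder rather than spelling them out by hand):

// Per-row (granularity 1) scaling along M, full-tile scaling along N, promoting every
// 8 MMA instructions, i.e. 8 * 32 = 256 elements along K for FP8.
using BlockScaledSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<
    /*ScaleGranularityM_=*/1, /*ScaleGranularityN_=*/0, /*ScalePromotionInterval_=*/8>;
static_assert(BlockScaledSchedule::ScalePromotionInterval % 4 == 0, "interval must be a multiple of 4");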
// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -310,12 +338,17 @@ template<
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0
int ScaleGranularityN = 0,
// `ScalePromotionInterval` specifies the interval at which the accumulator is promoted for scaling.
// It is required to be a multiple of 4 and is specified in terms of the number of MMA instructions
// in the reduction dimension, i.e., for FP8 kernels the default of 4 covers
// ScalePromotionInterval * MMA_K = 4 * 32 = 128 elements in K.
int ScalePromotionInterval = 4
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>>,
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN, ScalePromotionInterval>>,
"KernelSchedule must be one of the warp specialized policies");
};
@ -327,6 +360,7 @@ template<
>
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
constexpr static int Stages = Stages_;
constexpr static int PipelineAsyncMmaStages = 1;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;

@ -391,6 +425,26 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput {
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
// For FP8 kernels with Block Scaling
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative,
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0,
int ScalePromotionInterval = 4
>
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_same_v<KernelSchedule, KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>>,
"KernelSchedule must be one of the warp specialized policies");
};

template<
int SchedulerPipelineStageCount_,

@ -411,6 +465,14 @@ struct KernelTmaWarpSpecializedBlockScaledSm100 final {
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
};

template<
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_
>
struct KernelTmaWarpSpecializedMmaTransformSm100 final {
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
};

// InputTransform GEMM

@ -484,6 +546,13 @@ struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {};
struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {};
struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {};

///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 Blockwise GEMM Dispatch Policies
///////////////////////////////////////////////////////////////////////////////////////////////////////
struct KernelScheduleSm100Blockwise : KernelScheduleSm100 {};
struct KernelTmaWarpSpecializedBlockwise1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100Blockwise {};

///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 Planar Complex GEMM Dispatch Policies
///////////////////////////////////////////////////////////////////////////////////////////////////////

@ -530,6 +599,9 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelSch
struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { };
struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { };
struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { };
///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies
///////////////////////////////////////////////////////////////////////////////////////////////////////
// BlockScaled Dense GEMM + (Ptr Array or Group GEMM)
struct KernelSchedulePtrArrayBlockScaledGemmSm100 : KernelScheduleBlockScaledGemmSm100 {};
struct KernelSchedulePtrArrayMxNvf4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {};

@ -544,8 +616,6 @@ struct KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, K
struct KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { };
struct KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { };
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,

@ -561,7 +631,20 @@ struct MainloopSm100TmaUmmaWarpSpecialized {
constexpr static bool IsOverlappingAccum = false;
};

// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_,
class ClusterShape_ = Shape<_1,_1,_1>
>
struct MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm100;
using Schedule = KernelTmaWarpSpecializedMmaTransformSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
constexpr static bool IsOverlappingAccum = false;
};
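A hedged sketch of instantiating the new Blackwell blockwise-scaling mainloop directly (the stage counts and cluster shape are illustrative placeholders; the CollectiveBuilder normally selects them):

using BlockwiseMainloop = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling<
    /*Stages_=*/4, /*SchedulerPipelineStageCount_=*/2, /*AccumulatorPipelineStageCount_=*/2,
    cute::Shape<cute::_1, cute::_1, cute::_1>>;
// The schedule it selects is the MmaTransform variant introduced above; the accumulator pipeline is non-overlapping.
static_assert(!BlockwiseMainloop::IsOverlappingAccum, "blockwise scaling uses the non-overlapping accumulator path");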
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
template<

@ -65,6 +65,7 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp"

#include "cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp"
File diff suppressed because it is too large
@ -128,10 +128,11 @@ public:
using TileSchedulerParams = typename TileScheduler::Params;

static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaThreads = CUTE_STATIC_V(size(TiledMma{}));
static constexpr uint32_t NumMmaThreads = size(TiledMma{});
static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;

/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;

@ -434,7 +435,8 @@ public:
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = size(TiledMma{});
mainloop_pipeline_params.num_consumers = NumMmaThreads;
mainloop_pipeline_params.num_producers = NumProducerThreads;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});

@ -575,6 +577,7 @@ public:
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));

if (did_batch_change) {
load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch);
collective_mainloop.tensormaps_fence_acquire(input_tensormaps);
}

@ -131,6 +131,7 @@ public:
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;

/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;

@ -443,6 +444,7 @@ public:
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.num_producers = NumProducerThreads;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});

@ -607,6 +609,7 @@ public:
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));

if (did_batch_change) {
load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch);
collective_mainloop.tensormaps_fence_acquire(input_tensormaps);
}