v3.8.0 update (#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Yujia Zhai
2025-02-06 18:33:40 -08:00
committed by GitHub
parent affd1b693d
commit 833f6990e0
168 changed files with 24945 additions and 3436 deletions

View File

@@ -483,18 +483,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//
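Across these examples the 3.8 update tightens the architecture guard from "Hopper or later" to "exactly Hopper (compute capability 9.0)". A minimal sketch of the new gate as a standalone helper, assuming only the standard CUDA runtime API (the helper name is illustrative, not part of this diff):
// Hypothetical helper mirroring the updated guard: accept only compute capability 9.0 (SM90).
static bool is_sm90_device(int device_id) {
  cudaDeviceProp props;
  if (cudaGetDeviceProperties(&props, device_id) != cudaSuccess) return false;
  return props.major == 9 && props.minor == 0;  // Hopper only, no longer "9.x or greater"
}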

View File

@@ -566,17 +566,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -103,11 +103,10 @@
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "helper.h"
#include "mixed_dtype_utils.hpp"
#include "packed_scale.hpp"
#include "reorder_utils.hpp"
using namespace cute;
@@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using ValueShuffle = Layout<Shape<_2,_4>, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7]
int constexpr NumShuffleAtoms = 1;
using MmaAtomShape = Layout<Shape<_1,Int<NumShuffleAtoms>>>;
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
using ElementScale = MmaType;
using ElementZero = ElementScale;
@@ -438,14 +437,15 @@ void initialize(Options const& options) {
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref);
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);
if (options.shuffle) {
// Repeat the reorder layout atom to tile the whole tensor shape
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
cutlass::reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
print("Quantized tensor layout: ");
print(layout_B_reordered);
@@ -613,17 +613,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -107,11 +107,10 @@
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "helper.h"
#include "mixed_dtype_utils.hpp"
#include "packed_scale.hpp"
#include "reorder_utils.hpp"
using namespace cute;
@@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
// Define the CuTe layout for reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
using ElementScale = MmaType;
using ElementZero = ElementScale; // only for verify
@@ -349,10 +348,10 @@ void initialize(Options const& options) {
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
unify_quant_encoding(block_B, block_B_modified);
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_packed_scale(block_scale, block_scale_packed);
cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size());
initialize_zero(block_zero, options);
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
@@ -360,12 +359,13 @@ void initialize(Options const& options) {
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);
if (options.shuffle) {
// Repeat the reorder layout atom to tile the whole tensor shape
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
cutlass::reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
print("Quantized tensor layout: ");
print(layout_B_reordered);
@@ -518,17 +518,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -100,6 +100,7 @@
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "helper.h"
#include "mixed_dtype_utils.hpp"
@@ -322,9 +323,10 @@ void initialize(MixedDtypeOptions const& options) {
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref);
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);
}
/// Populates a Gemm::Arguments structure from the given commandline options
@@ -483,17 +485,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -60,8 +60,8 @@ struct MixedDtypeOptions {
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 1000;
int warmup = 1000;
int iterations = 100;
int warmup = 10;
int mode = 1;
int m = 5120, n = 4096, k = 4096;
int g = 128;
@@ -228,22 +228,18 @@ bool initialize_scale(
MixedDtypeOptions const& options,
uint64_t seed = 2023) {
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
std::vector<Element> stage(block.size(), Element(1.0f));
block.copy_from_host(stage.data());
}
else {
// If no scales, initialize with 1 so we can use the same kernel to dequantize the data
float scope_max = 1.0f, scope_min = 1.0f;
if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
const float max_dequant_val = 4.f;
const float min_dequant_val = 0.5f;
float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
scope_max = max_dequant_val / elt_max_f;
scope_min = min_dequant_val / elt_max_f;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
@@ -253,139 +249,14 @@ bool initialize_zero(
MixedDtypeOptions const& options,
uint64_t seed = 2023) {
// If no bias, initialize with 0 so we can use the same kernel to dequantize the data
float scope_max = 0.0f, scope_min = 0.0f;
if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
} else {
// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
std::vector<Element> stage(block.size(), Element(0.0f));
block.copy_from_host(stage.data());
scope_max = 2.0f;
scope_min = -2.0f;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
/// Dequantize the weights for verification
template <class QuantizedElement,
class DequantizedElement,
class OperandLayout,
class ElementScale,
class ElementZero,
class ScaleBroadCastLayout,
class ThrLayout>
__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
QuantizedElement const* q_buffer,
OperandLayout const operand_layout,
ElementScale const* scale_buffer,
ElementZero const* zero_buffer,
ScaleBroadCastLayout const broadcasted_scale_layout,
ThrLayout thr_layout) {
using namespace cute;
// Represent the full tensors to gmem elements.
// These are expected to have shape [MN, K, L]
cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
auto init_quantized_iterator = [&]() {
if constexpr (cute::sizeof_bits_v<QuantizedElement> >= 8) {
return cute::make_gmem_ptr(q_buffer);
} else {
return cute::subbyte_iterator<const QuantizedElement>(q_buffer);
}
};
cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout);
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
// It is expected that K % G == 0
cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
// Assign 1 thread per element in the thread block
auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); //
auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
// Tile across the block
auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
// Make a fragment of registers to hold gmem loads
cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
const auto num_iters = cute::size<3>(tOpDq_gOpDq);
for (int ii = 0; ii < num_iters; ++ii) {
const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
if (thread_offset < cute::size<0>(operand_layout)) {
cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{});
cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{});
cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
}
}
}
template <class QuantizedElement,
class DequantizedElement,
class OperandLayout,
class ElementScale,
class ElementZero,
class ScaleLayout>
void dequantize_weight(DequantizedElement* dq_buffer,
QuantizedElement const* q_buffer,
OperandLayout const operand_layout,
ElementScale const* scale_buffer,
ElementZero const* zero_buffer,
ScaleLayout const scale_layout,
int const group_size) {
using namespace cute;
constexpr int tpb = 128;
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
const auto num_rows = get<0>(shape(operand_layout));
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
if (num_rows != size<0>(scale_layout)) {
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
<< std::endl;
exit(-1);
}
const auto scale_stride0 = get<0>(stride(scale_layout));
const auto scale_stride1 = get<1>(stride(scale_layout));
const auto scale_stride2 = get<2>(stride(scale_layout));
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
const auto blocks_x = gemm_k;
const auto blocks_y = batches;
dim3 blocks(blocks_x, blocks_y, 1);
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
CUDA_CHECK(cudaDeviceSynchronize());
}
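The hand-written dequantize_weight kernel and launcher above are removed in favor of cutlass::dequantize, pulled in through the new cutlass/util/mixed_dtype_utils.hpp include; the updated examples invoke it with an explicit stream, as at the call sites earlier in this diff:
// Library call replacing the local dequantize_weight() helper (stream is the legacy default stream).
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B,
                    block_scale.get(), block_zero.get(), layout_scale_zero,
                    options.g, stream);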

View File

@@ -1,211 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstdint>
#include "cutlass/util/device_memory.h"
#include "cutlass/integer_subbyte.h"
#include "cutlass/float8.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cute/tensor.hpp"
#include "cute/util/type_traits.hpp"
namespace cutlass
{
template<typename T>
class packed_scale_t {
public:
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
cute::is_same_v<T, cutlass::uint8_t> ||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
cute::is_same_v<T, cutlass::float_e5m2_t>,
"only 8 bit arithmetic types are supported.");
CUTLASS_HOST_DEVICE
explicit packed_scale_t(T val) {
if constexpr (!cute::is_unsigned_v<T>) {
// Only pack negative values. The positive values are generated in flight in the mainloop.
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
}
else {
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
}
}
CUTLASS_HOST_DEVICE
packed_scale_t() = default;
CUTLASS_HOST_DEVICE
explicit operator float() const {
return float(get());
}
CUTLASS_HOST_DEVICE
bool operator==(packed_scale_t const& rhs) const {
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
}
CUTLASS_HOST_DEVICE
bool operator!=(packed_scale_t const& rhs) const {
return !(*this == rhs);
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() + rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() - rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() * rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() / rhs.get());
}
private:
using Storage = uint32_t;
using Stage = uint8_t;
Storage storage[2] {};
CUTLASS_HOST_DEVICE
static Storage pack4(T c1, T c2, T c3, T c4) {
Storage result = 0;
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
return result;
}
CUTLASS_HOST_DEVICE
T get() const {
auto stage = static_cast<Stage>(storage[0] >> 8);
#if defined(__CUDA_ARCH__)
return reinterpret_cast<T const&>(stage);
#else
T tmp;
std::memcpy(&tmp, &stage, sizeof(Stage));
return tmp;
#endif
}
CUTLASS_HOST_DEVICE
T get(int idx) const {
Stage stage;
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
#if defined(__CUDA_ARCH__)
return reinterpret_cast<T const&>(stage);
#else
T tmp;
std::memcpy(&tmp, &stage, sizeof(Stage));
return tmp;
#endif
}
};
}
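As the constructor comment above notes, only negative multiples are packed; an illustration of what a single packed_scale_t holds, following the pack4 byte order in the code (low byte first):
// Illustrative: build the lookup table for an e4m3 scale s = 0.5.
// storage[0] bytes (low to high) = { -8*s, -7*s, -6*s, -5*s }
// storage[1] bytes (low to high) = { -4*s, -3*s, -2*s, -1*s }
// The positive multiples are reconstructed in flight in the mainloop.
cutlass::packed_scale_t<cutlass::float_e4m3_t> lut(cutlass::float_e4m3_t(0.5f));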
/// Helpers to initialize scale lookup table
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
// Here the encodings of positive values and negative values are unified (except for the sign bit).
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
bool unify_quant_encoding(
cutlass::DeviceAllocation<cutlass::int4b_t> const& block_in,
cutlass::DeviceAllocation<cutlass::int4b_t>& block_out) {
using StorageType = cutlass::int4b_t::Storage;
if (block_in.size() != block_out.size()) {
std::cerr << "block_in and block_out must have same size.\n";
return false;
}
constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
std::vector<StorageType> data(block_in.size() / pack);
cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack);
for (auto&& d : data) {
StorageType out = 0;
StorageType mask = 0x0f;
for (int i = 0; i < pack; ++i) {
cutlass::int4b_t curr;
curr.storage = (d >> (i * 4)) & 0x0f;
switch (curr) {
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
default: break;
}
out |= (curr.storage << (4 * i)) & mask;
mask <<= 4;
}
d = out;
}
cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack);
return true;
}
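In 3.8 this helper is likewise replaced by a library routine; the int4 example in this diff now calls cutlass::unified_encode_int4b instead (see its initialize() hunk above):
// Library replacement for the local unify_quant_encoding() helper.
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());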
template <class ElementScale>
bool initialize_packed_scale(
cutlass::DeviceAllocation<ElementScale> const& block_in,
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8> > & block_out) {
std::vector<ElementScale> data_in(block_in.size());
std::vector<cutlass::Array<ElementScale, 8> > data_out(block_in.size());
try {
block_in.copy_to_host(data_in.data());
} catch (cutlass::cuda_exception const& e)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
for (size_t i = 0; i < block_in.size(); ++i)
{
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
}
try {
block_out.copy_from_host(data_out.data());
} catch (cutlass::cuda_exception const& e)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
return true;
}

View File

@@ -1,162 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/mma_sm90.hpp"
#include "cutlass/util/device_memory.h"
// Given a type of MMA instruction, compute a memory reordering atom that places all values
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
// In addition, we can reorder the values across several MMA instructions to get even wider
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
// more optimal conversion instruction sequences (ValLayout parameter).
template<class ElementMma,
class AtomLayout = cute::Layout<cute::_1>,
class ValLayout = cute::Layout<cute::_1>>
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
{
using namespace cute;
static_assert(is_static_v<ValLayout>, "ValLayout must be static");
static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
// 1. Choose an MMA atom to access TV layout and MN shape
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
using MmaTraits = MMA_Traits<MmaAtom>;
auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
auto tv_layout_mma = typename MmaTraits::ALayout{};
static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
// 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
// Note: this assumes A is partitioned between warps along M mode
auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
// 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
return layout_atom;
}
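A short usage sketch mirroring how the updated examples consume this atom through its new cutlass:: home (N, K, L and the device buffers are placeholders):
// Build the reorder atom for the MMA compute type, tile it over B's full (N,K,L) shape,
// then reorder the quantized data into that layout before launching the GEMM.
using LayoutAtomQuant   = decltype(cutlass::compute_memory_reordering_atom<cutlass::bfloat16_t>());
auto layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, cute::make_shape(N, K, L));
cutlass::reorder_tensor(block_B.get(), layout_B, layout_B_reordered);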
template<class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
__global__ void reorder_tensor_kernel(
cute::Tensor<EngineSrc, LayoutSrc> S,
cute::Tensor<EngineDst, LayoutDst> D,
TiledCopy tiled_copy)
{
using namespace cute;
using T = typename EngineDst::value_type;
Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
auto thread_copy = tiled_copy.get_slice(threadIdx.x);
Tensor tS = thread_copy.partition_S(gS);
Tensor tD = thread_copy.partition_D(gD);
copy(tiled_copy, tS, tD);
}
template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
void reorder_tensor(
cute::Tensor<EngineSrc, LayoutSrc> S,
cute::Tensor<EngineDst, LayoutDst> D)
{
using namespace cute;
using T = typename EngineDst::value_type;
static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
// Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
// This avoids a race condition when writing out subbyte types (e.g. int4b_t).
auto has_major_mode = [](auto s) {
return any_of(s, [](auto a){ return is_constant<1, decltype(a)>{}; });
};
static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
"Could not find stride-1 mode in destination layout");
constexpr int N = shape_div(Int<8>{}, sizeof_bits<T>{});
auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
// Make a tiled copy with a simple row-major thread order and above layout
int constexpr NumThreads = 128;
auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
// Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
using TileShape = Shape<_16>;
auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
CUDA_CHECK(cudaDeviceSynchronize());
}
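Worked through for the subbyte case the value-layout comment above is guarding against (assuming T = cutlass::int4b_t): with 4-bit elements, N = 8 / 4 = 2 destination values stay contiguous per thread, so each thread owns a whole output byte and subbyte writes never race:
// 8 bits / 4 bits per int4 element = 2 contiguous destination values per thread.
constexpr int N_int4 = 8 / cutlass::sizeof_bits<cutlass::int4b_t>::value;
static_assert(N_int4 == 2, "each thread writes one full byte of int4 output");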
// In-place version
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T const* src,
LayoutSrc const& layout_src,
T * dst,
LayoutDst const& layout_dst)
{
using namespace cute;
reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
make_tensor(make_gmem_ptr<T>(dst), layout_dst));
}
// In-place version
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T * data,
LayoutSrc const& layout_src,
LayoutDst const& layout_dst)
{
using namespace cute;
cutlass::DeviceAllocation<T> temp(size(layout_src));
reorder_tensor(data, layout_src, temp.get(), layout_dst);
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
}

View File

@@ -513,17 +513,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -731,17 +731,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -768,16 +768,26 @@ int main(int argc, char const** argv) {
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4) ||
(props.major != 8 && props.minor != 9)) {
bool satisfied;
if (props.major < 10) {
// Pre-Blackwell
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4);
satisfied &= (props.major > 8) || (props.major == 8 && props.minor == 9);
}
else {
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
}
if (!satisfied) {
//
// This example requires an NVIDIA Ada-architecture GPU.
// This example requires an NVIDIA GPU with compute capability 8.9 or greater.
//
std::cout
<< "CUTLASS's FP8 SM89 example requires a GPU of NVIDIA's Ada architecture "
<< "and CUDA toolkit version 12.4 or later.\n";
<< "CUTLASS's FP8 SM89 example requires an NVIDIA GPU with compute capability 8.9 or greater "
<< "and CUDA toolkit version 12.4 or later"
<< " (12.8 or later needed for SM100+)"
<< std::endl;
return 0;
}
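Condensing the new gate added above into one expression (an illustrative restatement, not code from the diff; props is the cudaDeviceProp queried earlier): compute capability 8.9 or greater with CUDA 12.4+, and CUDA 12.8+ once the device is SM100 (Blackwell) or newer.
// Pre-Blackwell (CC < 10.0): require CUDA 12.4+ and CC >= 8.9. SM100+: require CUDA 12.8+.
bool cuda_12_4 = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4);
bool cuda_12_8 = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
bool cc_8_9_up = (props.major > 8) || (props.major == 8 && props.minor == 9);
bool satisfied = (props.major < 10) ? (cuda_12_4 && cc_8_9_up) : cuda_12_8;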

View File

@@ -504,17 +504,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -570,18 +570,13 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//

View File

@@ -469,18 +469,12 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
else if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//
// Parse options

View File

@@ -31,28 +31,20 @@
/*! \file
\brief Grouped-scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
This example demonstrates a grouped scaled FP8 GEMM using the new CUTLASS 3.0
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
which are more efficient than the Ampere tensor core instructions.
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
copies between thread blocks in a cluster.
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
4. This example shows all important fusions used by FP8 gemm kernels, i.e., grouped scale factor along M for
A, blocked scale factor along K for A tensor, blocked scale factor for B tensor, the abs_max value of D tensor.
5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
Examples:
$ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \
--m=2816 --n=3072 --k=16384 \
--save_aux=false --save_amax=false \

View File

@@ -34,4 +34,4 @@ cutlass_example_add_executable(
cutlass_example_add_executable(
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
)
)

View File

@@ -191,7 +191,7 @@ void gett_mainloop(
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;

View File

@@ -0,0 +1,818 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
NOTE: Write docu
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <numeric>
#include <typeinfo>
#include <float.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "helper.h"
#include "grouped_mixed_dtype_utils.hpp"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = QuantType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
// Need to pass a pointer type to make the 3rd dimension of Stride be _0
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
// Define the CuTe layout for reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
// using ValueShuffle = Layout<_1>; // no value reordering
using ValueShuffle = Layout<Shape<_2,_4>, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7]
int constexpr NumShuffleAtoms = 1;
using MmaAtomShape = Layout<Shape<_1,Int<NumShuffleAtoms>>>;
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,Int<1>>, StrideB>{}));
using ElementZero = cutlass::bfloat16_t;
using ElementScale = cutlass::bfloat16_t;
using LayoutScale = cutlass::layout::RowMajor;
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementB, LayoutB_Transpose *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopConvertOnly,
CollectiveEpilogue
>;
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
using CollectiveMainloopConvertOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementB, LayoutB_Reordered *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelConvertOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopConvertOnlyShuffled,
CollectiveEpilogue
>;
using GemmConvertOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnlyShuffled>;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using CollectiveMainloopScaleOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale>, LayoutB_Reordered *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopScaleOnlyShuffled,
CollectiveEpilogue
>;
using GemmScaleOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnlyShuffled>;
using StrideC = typename GemmKernelConvertOnly::InternalStrideC;
using StrideD = typename GemmKernelConvertOnly::InternalStrideD;
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_B_dq;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_scale;
std::vector<int64_t> offset_zero;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<StrideC_ref> stride_C_host_ref;
std::vector<StrideD_ref> stride_D_host_ref;
std::vector<StrideS> stride_S_host;
std::vector<StrideS_ref> stride_S_host_ref;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
uint64_t seed = 2020;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<MmaType> block_A;
cutlass::DeviceAllocation<QuantType> block_B;
cutlass::DeviceAllocation<MmaType> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
cutlass::DeviceAllocation<const MmaType *> ptr_A;
cutlass::DeviceAllocation<const QuantType *> ptr_B;
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
cutlass::DeviceAllocation<const ElementScale *> ptr_scale;
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<LayoutB_Reordered> layout_B_reordered;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
cutlass::DeviceAllocation<StrideS> stride_S;
// Note, this is an array of pointers to alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options : GroupedMixedDtypeOptions<QuantType> {
using Base = GroupedMixedDtypeOptions<QuantType>;
bool shuffle = true;
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
cmd.get_cmd_line_argument("shuffle", shuffle);
this->Base::parse(argc, args);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "69_hopper_int4_bf16_grouped_gemm\n\n"
<< " Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --warmup=<int> Number of warmup iterations to perform\n\n"
<< " --shuffle=<boolean> Enable the offline layout swizzling.\n\n"
<< " --benchmark=<str> Executes a benchmark problem size.\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "69_hopper_int4_bf16_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n";
return out;
}
};
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocates device-side data
void allocate(Options const& options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_B_dq = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
int64_t total_elements_scale = 0;
int64_t total_elements_zero = 0;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
const int scale_k = 1;
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
offset_B_dq.push_back(total_elements_B_dq);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_scale.push_back(total_elements_scale);
offset_zero.push_back(total_elements_zero);
int64_t elements_A = M * K;
int64_t elements_B = K * N ;
int64_t elements_B_dq = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
int64_t elements_scale = scale_k * N;
int64_t elements_zero = scale_k * N;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_B_dq += elements_B_dq;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_scale += elements_scale;
total_elements_zero += elements_zero;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_B_dq.reset(total_elements_B_dq);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
block_scale.reset(total_elements_scale);
block_zero.reset(total_elements_zero);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options &options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
//
// Assign pointers
//
std::vector<MmaType *> ptr_A_host(options.groups);
std::vector<QuantType *> ptr_B_host(options.groups);
std::vector<MmaType *> ptr_B_dq_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
std::vector<ElementScale *> ptr_scale_host(options.groups);
std::vector<ElementZero *> ptr_zero_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i);
ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_B_dq.reset(options.groups);
ptr_B_dq.copy_from_host(ptr_B_dq_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
ptr_scale.reset(options.groups);
ptr_scale.copy_from_host(ptr_scale_host.data());
ptr_zero.reset(options.groups);
ptr_zero.copy_from_host(ptr_zero_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
stride_C_ref.reset(options.groups);
stride_C_ref.copy_from_host(stride_C_host_ref.data());
stride_D_ref.reset(options.groups);
stride_D_ref.copy_from_host(stride_D_host_ref.data());
stride_S_ref.reset(options.groups);
stride_S_ref.copy_from_host(stride_S_host_ref.data());
stride_S.reset(options.groups);
stride_S.copy_from_host(stride_S_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
initialize_quant_tensor(block_B, seed + 2022);
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
for (int32_t i = 0; i < options.groups; ++i) {
const int scale_k = 1;
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{});
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i));
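// Dequantize this group's B before any optional reordering of the quantized buffer, so that
// verify() can later run the reference GEMM directly on block_B_dq.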
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream);
}
problem_sizes.reset(options.groups);
if (options.shuffle) {
std::vector<LayoutB_Reordered> layout_B_reordered_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
// Repeat the reorder layout atom to tile the whole tensor shape
layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B);
cutlass::reorder_tensor(block_B.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]);
if (i == 0) {
print("Quantized tensor layout: ");
print(layout_B_reordered_host[0]);
print("\n");
}
}
layout_B_reordered.reset(options.groups);
layout_B_reordered.copy_from_host(layout_B_reordered_host.data());
}
// Reverse MN -> NM for SwapAB
for (int32_t i = 0; i < options.groups; ++i) {
auto [M, N, K] = options.problem_sizes_host[i];
options.problem_sizes_host[i] = make_tuple(N, M, K);
}
problem_sizes.copy_from_host(options.problem_sizes_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Gemm>
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
{
using Args = typename Gemm::Arguments;
auto&& dB = [&]() {
// NOTE: add GemmScaleWithZeroPointShuffled
if constexpr (cute::is_same_v<Gemm, GemmConvertOnlyShuffled> ||
cute::is_same_v<Gemm, GemmScaleOnlyShuffled>) {
// offline swizzling is enabled.
return layout_B_reordered.get();
}
else {
return stride_B.get();
}
}();
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
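// For example, on a multi-GPU machine one could call cudaSetDevice(1) early in main() and set
// hw_info.device_id = 1 here so the SM count is queried from the same GPU that runs the kernel.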
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per group
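// (The trailing element of dAlpha/dBeta is the group stride: 0 reuses a single scalar for
// every group, 1 advances to the next group's value in the pointer arrays above.)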
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) {
arguments = Args {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), dB, ptr_A.get(), stride_A.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) {
arguments = Args {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else {
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
exit(-1);
}
return arguments;
}
bool verify(Options const& options) {
bool passed = true;
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
ScheduleRef
>::CollectiveOp;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
const ElementD epsilon(1e-2f);
const ElementD non_zero_floor(1e-4f);
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
if (M == 0) {
continue;
}
else {
StrideA_verif stride_A_verif;
StrideB_verif stride_B_verif;
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
//
// Compute reference output
//
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
};
// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
}
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
std::cout << "We passed all checks\n";
// Check if output from CUTLASS kernel and reference kernel are equal or not
MixedDtypeResult result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
if (!result.passed) {
exit(-1);
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits; its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
std::cout << "Running in no scale mode." << std::endl;
if (options.shuffle) {
std::cout << "Offline shuffle enabled." << std::endl;
run<GemmConvertOnlyShuffled>(options, false);
} else {
std::cout << "Offline shuffle disabled." << std::endl;
run<GemmConvertOnly>(options, false);
}
}
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
std::cout << "Running in per-column scale mode." << std::endl;
if (options.shuffle) {
std::cout << "Offline shuffle enabled." << std::endl;
run<GemmScaleOnlyShuffled>(options, false);
} else {
std::cout << "Offline shuffle disabled." << std::endl;
run<GemmScaleOnly>(options, false);
}
}
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,753 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Hopper grouped GEMM example multiplying an FP8 (e4m3) A operand by an INT4 quantized B operand,
applying per-column scales (packed for the FP8 converter) in the mainloop, with an optional
offline shuffle of the quantized B layout (--shuffle).
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <numeric>
#include <typeinfo>
#include <float.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "helper.h"
#include "grouped_mixed_dtype_utils.hpp"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
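// TileShapeK covers 128 bytes of MmaType along K: 128 * 8 / 8 bits = 128 elements for the FP8 (e4m3) MmaType used here.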
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = QuantType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
// Need to pass a pointer type to make the 3rd dimension of Stride be _0
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
// Define the CuTe layout for the reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,Int<1>>, StrideB>{}));
using ElementZero = cutlass::float_e4m3_t;
using ElementScale = cutlass::float_e4m3_t;
using LayoutScale = cutlass::layout::RowMajor;
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
EpilogueSchedule
>::CollectiveOp;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Transpose *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;
using CollectiveMainloopShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Reordered *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopShuffled,
CollectiveEpilogue
>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using GemmShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;
using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_B_dq;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_scale;
std::vector<int64_t> offset_zero;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<StrideC_ref> stride_C_host_ref;
std::vector<StrideD_ref> stride_D_host_ref;
std::vector<StrideS> stride_S_host;
std::vector<StrideS_ref> stride_S_host_ref;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
uint64_t seed = 2020;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<MmaType> block_A;
cutlass::DeviceAllocation<QuantType> block_B;
cutlass::DeviceAllocation<ElementB> block_B_modified;
cutlass::DeviceAllocation<MmaType> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8>> block_scale_packed;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
cutlass::DeviceAllocation<const MmaType *> ptr_A;
cutlass::DeviceAllocation<const QuantType *> ptr_B;
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
cutlass::DeviceAllocation<const cutlass::Array<ElementScale, 8> *> ptr_scale_packed;
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<LayoutB_Reordered> layout_B_reordered;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
cutlass::DeviceAllocation<StrideS> stride_S;
// Note, this is an array of pointers to alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options : GroupedMixedDtypeOptions<QuantType> {
using Base = GroupedMixedDtypeOptions<QuantType>;
bool shuffle = true;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
cmd.get_cmd_line_argument("shuffle", shuffle);
this->Base::parse(argc, args);
mode = 1; // override the mode value to always be scale only mode
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "69_hopper_int4_fp8_grouped_gemm\n\n"
<< " Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --c=<int> The size of each chunk for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --warmup=<int> Number of warmup iterations to perform\n\n"
<< " --shuffle=<boolean> Enable the offline layout swizzling.\n\n"
<< " --benchmark=<str> Executes a benchmark problem size.\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "69_hopper_int4_fp8_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n";
return out;
}
};
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
// In the mainloop, a PRMT instruction can only select a byte from eight source bytes, so the sign bit is handled by an extra PRMT.
// Here the encodings of positive values and negative values are unified (except for the sign bit).
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
/// Allocates device-side data
void allocate(Options const& options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_B_dq = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
int64_t total_elements_scale = 0;
int64_t total_elements_zero = 0;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
const int scale_k = 1;
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
offset_B_dq.push_back(total_elements_B_dq);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_scale.push_back(total_elements_scale);
offset_zero.push_back(total_elements_zero);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_B_dq = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
int64_t elements_scale = scale_k * N;
int64_t elements_zero = scale_k * N;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_B_dq += elements_B_dq;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_scale += elements_scale;
total_elements_zero += elements_zero;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_B_modified.reset(total_elements_B);
block_B_dq.reset(total_elements_B_dq);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
block_scale.reset(total_elements_scale);
block_scale_packed.reset(total_elements_scale);
block_zero.reset(total_elements_zero);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options& options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
//
// Assign pointers
//
std::vector<MmaType *> ptr_A_host(options.groups);
std::vector<QuantType *> ptr_B_host(options.groups);
std::vector<MmaType *> ptr_B_dq_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
std::vector<cutlass::Array<ElementScale, 8> *> ptr_scale_packed_host(options.groups);
std::vector<ElementZero *> ptr_zero_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B_modified.get() + offset_B.at(i);
ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_scale_packed_host.at(i) = block_scale_packed.get() + offset_scale.at(i);
ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_B_dq.reset(options.groups);
ptr_B_dq.copy_from_host(ptr_B_dq_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
ptr_scale_packed.reset(options.groups);
ptr_scale_packed.copy_from_host(ptr_scale_packed_host.data());
ptr_zero.reset(options.groups);
ptr_zero.copy_from_host(ptr_zero_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
stride_C_ref.reset(options.groups);
stride_C_ref.copy_from_host(stride_C_host_ref.data());
stride_D_ref.reset(options.groups);
stride_D_ref.copy_from_host(stride_D_host_ref.data());
stride_S_ref.reset(options.groups);
stride_S_ref.copy_from_host(stride_S_host_ref.data());
stride_S.reset(options.groups);
stride_S.copy_from_host(stride_S_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
initialize_quant_tensor(block_B, seed + 2022);
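// Re-encode the int4 values into block_B_modified so positive and negative magnitudes share the
// same low-bit patterns (see the PRMT note above); the kernel reads the re-encoded buffer via ptr_B.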
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
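// Pack the FP8 scales into the cutlass::Array<ElementScale, 8> form consumed by the mainloop;
// the kernel receives these packed scales through ptr_scale_packed.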
cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size());
initialize_zero(block_zero, options);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
problem_sizes.reset(options.groups);
if (options.shuffle) {
std::vector<LayoutB_Reordered> layout_B_reordered_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
// Repeat the reorder layout atom to tile the whole tensor shape
layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B);
cutlass::reorder_tensor(block_B_modified.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]);
if (i == 0) {
print("Quantized tensor layout: ");
print(layout_B_reordered_host[0]);
print("\n");
}
}
layout_B_reordered.reset(options.groups);
layout_B_reordered.copy_from_host(layout_B_reordered_host.data());
}
// Reverse MN -> NM for SwapAB
for (int32_t i = 0; i < options.groups; ++i) {
auto [M, N, K] = options.problem_sizes_host[i];
options.problem_sizes_host[i] = make_tuple(N, M, K);
}
problem_sizes.copy_from_host(options.problem_sizes_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Gemm>
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
{
using Args = typename Gemm::Arguments;
auto&& dB = [&]() {
if constexpr (cute::is_same_v<Gemm, GemmShuffled>) { // offline swizzling is enabled.
return layout_B_reordered.get();
}
else {
return stride_B.get();
}
}();
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
arguments = Args {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
return arguments;
}
bool verify(Options const& options) {
bool passed = true;
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
ScheduleRef
>::CollectiveOp;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
const ElementD epsilon(1e-2f);
const ElementD non_zero_floor(1e-4f);
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
if (M == 0) {
continue;
}
else {
StrideA_verif stride_A_verif;
StrideB_verif stride_B_verif;
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
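// Dequantize this group's B into MmaType on the fly so the reference GEMM below runs on
// plain FP8 data; the source is the original block_B, not the re-encoded/reordered buffer.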
const int scale_k = 1;
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
//
// Compute reference output
//
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
};
// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
}
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
std::cout << "We passed all checks\n";
// Check if output from CUTLASS kernel and reference kernel are equal or not
MixedDtypeResult result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
if (!result.passed) {
exit(-1);
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits; its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
std::cout << "Running in per-column scale mode." << std::endl;
if (options.shuffle) {
std::cout << "Offline shuffle enabled." << std::endl;
run<GemmShuffled>(options, false);
} else {
std::cout << "Offline shuffle disabled." << std::endl;
run<GemmScaleOnly>(options, false);
}
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,678 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Hopper grouped GEMM example multiplying a BF16 A operand by an FP8 (e5m2) quantized B operand,
supporting direct conversion (mode 0) or per-column scaling (mode 1) of B in the mainloop.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <numeric>
#include <typeinfo>
#include <float.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "helper.h"
#include "grouped_mixed_dtype_utils.hpp"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::float_e5m2_t;
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
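// TileShapeK covers 128 bytes of MmaType along K: 128 * 8 / 16 bits = 64 elements for the BF16 MmaType used here.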
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
// A matrix configuration
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = QuantType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using ElementZero = cutlass::bfloat16_t;
using ElementScale = cutlass::bfloat16_t;
using LayoutScale = cutlass::layout::RowMajor;
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementB, LayoutB_Transpose *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopConvertOnly,
CollectiveEpilogue
>;
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose *, AlignmentB,
ElementA, LayoutA_Transpose *, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideA = typename GemmConvertOnly::GemmKernel::InternalStrideA;
using StrideB = typename GemmConvertOnly::GemmKernel::InternalStrideB;
using StrideC = typename GemmConvertOnly::GemmKernel::InternalStrideC;
using StrideD = typename GemmConvertOnly::GemmKernel::InternalStrideD;
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_B_dq;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_scale;
std::vector<int64_t> offset_zero;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<StrideC_ref> stride_C_host_ref;
std::vector<StrideD_ref> stride_D_host_ref;
std::vector<StrideS> stride_S_host;
std::vector<StrideS_ref> stride_S_host_ref;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
uint64_t seed = 2020;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<MmaType> block_A;
cutlass::DeviceAllocation<QuantType> block_B;
cutlass::DeviceAllocation<MmaType> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
cutlass::DeviceAllocation<const MmaType *> ptr_A;
cutlass::DeviceAllocation<const QuantType *> ptr_B;
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
cutlass::DeviceAllocation<const ElementScale *> ptr_scale;
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
cutlass::DeviceAllocation<StrideS> stride_S;
// Note, this is an array of pointers to alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using Options = GroupedMixedDtypeOptions<QuantType>;
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocates device-side data
void allocate(Options const& options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_B_dq = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
int64_t total_elements_scale = 0;
int64_t total_elements_zero = 0;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
const int scale_k = 1;
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
offset_B_dq.push_back(total_elements_B_dq);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_scale.push_back(total_elements_scale);
offset_zero.push_back(total_elements_zero);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_B_dq = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
int64_t elements_scale = scale_k * N;
int64_t elements_zero = scale_k * N;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_B_dq += elements_B_dq;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_scale += elements_scale;
total_elements_zero += elements_zero;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_B_dq.reset(total_elements_B_dq);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
block_scale.reset(total_elements_scale);
block_zero.reset(total_elements_zero);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options &options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
//
// Assign pointers
//
std::vector<MmaType *> ptr_A_host(options.groups);
std::vector<QuantType *> ptr_B_host(options.groups);
std::vector<MmaType *> ptr_B_dq_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
std::vector<ElementScale *> ptr_scale_host(options.groups);
std::vector<ElementZero *> ptr_zero_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i);
ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_B_dq.reset(options.groups);
ptr_B_dq.copy_from_host(ptr_B_dq_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
ptr_scale.reset(options.groups);
ptr_scale.copy_from_host(ptr_scale_host.data());
ptr_zero.reset(options.groups);
ptr_zero.copy_from_host(ptr_zero_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
stride_C_ref.reset(options.groups);
stride_C_ref.copy_from_host(stride_C_host_ref.data());
stride_D_ref.reset(options.groups);
stride_D_ref.copy_from_host(stride_D_host_ref.data());
stride_S_ref.reset(options.groups);
stride_S_ref.copy_from_host(stride_S_host_ref.data());
stride_S.reset(options.groups);
stride_S.copy_from_host(stride_S_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
initialize_quant_tensor(block_B, seed + 2022);
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
problem_sizes.reset(options.groups);
// Reverse MN -> NM for SwapAB
for (int32_t i = 0; i < options.groups; ++i) {
auto [M, N, K] = options.problem_sizes_host[i];
options.problem_sizes_host[i] = make_tuple(N, M, K);
}
problem_sizes.copy_from_host(options.problem_sizes_host.data());
}
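// A short aside on the MN -> NM swap above (reasoning sketch, nothing extra is
// computed): the mainloop keeps the quantized operand in its first ("A") slot, so
// the problem is posed transposed. A group originally sized (M, N, K) is reported
// to the kernel as (N, M, K), args_from_options() passes {ptr_B, stride_B, ptr_A,
// stride_A}, and verify() undoes the swap by reading N = get<0>(problem) and
// M = get<1>(problem).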
/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Gemm>
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per each group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
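// Broadcast convention assumed above: dAlpha/dBeta are (_0, _0, group_stride)
// strides. A group stride of 0 makes every group read the same scalar alpha/beta,
// while a group stride of 1 makes group i read alpha_ptr_array[i] and
// beta_ptr_array[i], which is how the per-group values filled in initialize()
// reach the epilogue.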
if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else {
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
exit(-1);
}
return arguments;
}
bool verify(Options const& options) {
bool passed = true;
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
ScheduleRef
>::CollectiveOp;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
const ElementD epsilon(1e-2f);
const ElementD non_zero_floor(1e-4f);
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
if (M == 0) {
continue;
}
else {
StrideA_verif stride_A_verif;
StrideB_verif stride_B_verif;
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
const int scale_k = 1;
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
//
// Compute reference output
//
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
};
// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
}
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
std::cout << "We passed all checks\n";
// Check if output from CUTLASS kernel and reference kernel are equal or not
MixedDtypeResult result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
if (!result.passed) {
exit(-1);
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
std::cout << "Running in no scale mode." << std::endl;
run<GemmConvertOnly>(options, false);
}
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
std::cout << "Running in group scale mode." << std::endl;
run<GemmScaleOnly>(options, false);
}
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,112 @@
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Note that most of the tests below set --iterations=0 to disable performance benchmarking,
# so those commands run only the correctness check. The *_OP and *_PERF tests run a small number of profiling iterations.
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=100 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=2.0 --beta=2.0 --groups=100 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=0.25 --iterations=1) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=16 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=100 --iterations=0) # Fixed problem sizes
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=100 --iterations=0) # Small problem sizes
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10) # Random problem sizes
set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling
cutlass_example_add_executable(
69_hopper_mixed_dtype_grouped_gemm
69_hopper_mixed_dtype_grouped_gemm.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
)
cutlass_example_add_executable(
69_hopper_int4_fp8_grouped_gemm
69_hopper_int4_fp8_grouped_gemm.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
)
cutlass_example_add_executable(
69_hopper_int4_bf16_grouped_gemm
69_hopper_int4_bf16_grouped_gemm.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
)

View File

@ -0,0 +1,194 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <vector>
#include <fstream>
#include <stdexcept>
#include "../55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp"
template<class QuantType>
class GroupedMixedDtypeOptions : public MixedDtypeOptions {
public:
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int,int,int>>;
using UnderlyingProblemShape = typename ProblemShape::UnderlyingProblemShape;
int groups = 6;
int c = 512;
std::string benchmark_path;
std::vector<UnderlyingProblemShape> problem_sizes_host;
GroupedMixedDtypeOptions() : MixedDtypeOptions()
{
m = 1024;
n = 2048;
k = 512;
};
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("c", c);
MixedDtypeOptions::parse(argc, args);
problem_sizes_host = benchmark_path.empty() ? randomize_problems(cmd) : load_benchmark_problems();
}
std::ostream& print_usage(std::ostream& out) const {
out << "69_hopper_mixed_dtype_grouped_gemm\n\n"
<< "Options:\n"
<< " --help Display this usage statement\n"
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --groups=<int> Sets the number of individual GEMM problems\n"
<< " --mode=<int> The mode to run the gemm\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations\n"
<< " --warmup=<int> Number of warmup iterations\n"
<< " --benchmark=<str> Executes a benchmark problem size\n";
return out;
}
double gflops(double runtime_s) const {
uint64_t fmas = std::accumulate(problem_sizes_host.begin(), problem_sizes_host.end(), 0ULL,
[](uint64_t sum, const UnderlyingProblemShape& problem) {
return sum + static_cast<uint64_t>(cute::get<0>(problem)) *
static_cast<uint64_t>(cute::get<1>(problem)) *
static_cast<uint64_t>(cute::get<2>(problem));
});
return (2.0 * fmas) / (runtime_s * 1e9);
}
private:
static constexpr int tma_alignment_bits = 128;
const int alignment = tma_alignment_bits / cutlass::sizeof_bits<QuantType>::value;
std::vector<UnderlyingProblemShape> randomize_problems(cutlass::CommandLine& cmd) {
std::vector<UnderlyingProblemShape> problems;
problems.reserve(groups);
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
for (int i = 0; i < groups; ++i) {
int m = (cmd_line_m >= 0) ? cmd_line_m : alignment * ((rand() % 64) + 1);
int n = (cmd_line_n >= 0) ? cmd_line_n : this->n;
int k = (cmd_line_k >= 0) ? cmd_line_k : this->k;
if (k % alignment != 0) {
throw std::runtime_error("Error: k dimension must be a multiple of " + std::to_string(alignment));
}
problems.push_back({m, n, k});
}
return problems;
}
std::vector<UnderlyingProblemShape> load_benchmark_problems() {
std::ifstream file(benchmark_path);
if (!file) {
throw std::runtime_error("Failed to open benchmark file: " + benchmark_path);
}
std::vector<UnderlyingProblemShape> problems;
int idx;
std::string extent_str;
while (file >> idx >> extent_str) {
if (idx < 0 || extent_str.empty()) break;
std::vector<std::string> tokens;
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
cutlass::gemm::GemmCoord extent;
for (int i = 0; i < std::min(3, static_cast<int>(tokens.size())); ++i) {
int x = std::stoi(tokens[i]);
extent.at(i) = (x % alignment) ? x + (alignment - (x % alignment)) : x;
}
if (extent.product()) {
problems.push_back({extent.m(), extent.n(), extent.k()});
}
}
groups = static_cast<int>(problems.size());
return problems;
}
};
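// Worked example of the alignment handling above: with an int4 QuantType,
// alignment = 128 / 4 = 32 elements, so a benchmark extent of 100 is rounded up to
// 100 + (32 - 100 % 32) = 128, and a randomized M is always drawn as a multiple of
// 32 via alignment * ((rand() % 64) + 1).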
template <class QuantType, class Gemm, class ElementAccumulator>
void grouped_mixed_dtype_profiling(
Gemm& gemm,
const GroupedMixedDtypeOptions<QuantType>& options,
MixedDtypeResult& result,
const std::vector<ElementAccumulator>& alpha_host,
const std::vector<ElementAccumulator>& beta_host) {
if (options.iterations <= 0) return;
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
std::vector<float> runtimes;
runtimes.reserve(options.iterations);
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
cudaEventRecord(start);
CUTLASS_CHECK(gemm.run());
cudaEventRecord(stop);
cudaEventSynchronize(stop);
if (iter >= options.warmup) {
float milliseconds = 0;
cudaEventElapsedTime(&milliseconds, start, stop);
runtimes.push_back(milliseconds);
}
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Sizes, Alpha, Beta\n";
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n';
}
std::cout << " Groups : " << options.groups << '\n'
<< " Avg runtime : " << result.avg_runtime_ms << " ms\n"
<< " GFLOPS : " << result.gflops << '\n';
}
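// Worked example of the reported numbers (hypothetical figures): with 6 groups of
// 1024 x 2048 x 512, fmas = 6 * 1024 * 2048 * 512 ~= 6.44e9, so an average runtime
// of 1 ms is reported as 2 * fmas / (1e-3 * 1e9) ~= 12,885 GFLOPS. Warmup
// iterations are timed but excluded from the average, since only iterations past
// options.warmup are pushed into runtimes.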

View File

@ -124,13 +124,14 @@ constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // A
using ElementAccumulator = float; // Element type for internal accumulation
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands
constexpr int OutputSFVectorSize = 16;
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
cutlass::epilogue::thread::SiLu,
OutputSFVectorSize,
ElementD,
ElementAccumulator,
ElementSF,
ElementSFD,
LayoutC,
ElementC>;

View File

@ -466,7 +466,6 @@ struct ExampleRunner {
int max_seqlen_kv = 0;
for (auto e : seqlen_kv) {
// if (options.varlen) std::cout << "seqlen " << e << std::endl;
max_seqlen_kv = std::max(e, max_seqlen_kv);
}

View File

@ -29,11 +29,11 @@
set_property(
SOURCE 77_blackwell_fmha.cu
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
set_property(
SOURCE 77_blackwell_fmha_gen.cu
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)

View File

@ -529,7 +529,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _);
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
@ -822,9 +821,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i);
// tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
// tTMEM_LOADsO_i.data() = tTMEM_LOADsO_i.data().get() + sO.layout()(_0{}, i * kCorrectionTileSize, _0{});
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
@ -939,8 +936,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
//tTMrO(j) = scale * tTMrO(j);
//tTMrO(j+1) = scale * tTMrO(j+1);
}
copy_out(i);

View File

@ -538,7 +538,6 @@ struct Sm100FmhaGenMainloopWarpspecialized {
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _);
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
@ -956,8 +955,6 @@ struct Sm100FmhaGenMainloopWarpspecialized {
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
//tTMrO(j) = scale * tTMrO(j);
//tTMrO(j+1) = scale * tTMrO(j+1);
}
copy_out(i);

View File

@ -188,21 +188,15 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
// Q1
int q0_index = get<0>(blk_coord);
// pipeline_q.producer_acquire(pipeline_q_producer_state);
// copy_with_limit(tiled_copy_q, tQcQ, limitQ, tQgQ, tQsQ(_, _, _, _, pipeline_q_producer_state.index());
auto load_q = [&](int q_index, auto& state) {
pipeline_q.producer_acquire(state);
// using Vec = Element;
// auto vzero = Element(0);
// q is always loaded masked
using Vec = uint128_t;
Vec vzero = uint128_t(0, 0);
//auto src = recast<Vec>(tQgQ(_, _, _, _, q_index));
auto src = recast<Vec>(tQgQ(_, _, _, _));
auto dst = recast<Vec>(tQsQ(_, _, _, _, state.index()));
// auto c = tQcQ(_, _, _, _, q_index);
auto c = tQcQ(_, _, _, _);
int vlen = sizeof(Vec) / sizeof(Element);
CUTLASS_PRAGMA_UNROLL
@ -220,7 +214,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
};
load_q(q0_index, pipeline_q_producer_state);
// pipeline_q.producer_commit(pipeline_q_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_q_producer_state;
auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{}));
@ -287,8 +280,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index()));
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
} else {
// using Vec = Element;
// auto vzero = Element(0);
using Vec = uint128_t;
Vec vzero = uint128_t(0, 0);
auto src = recast<Vec>(tKgK(_, _, _, _, k_index));
@ -322,8 +313,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index()));
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
} else {
// using Vec = Element;
// auto vzero = Element(0);
using Vec = uint128_t;
Vec vzero = uint128_t(0, 0);
auto src = recast<Vec>(tVgV(_, _, _, _, v_index));

View File

@ -149,8 +149,6 @@ void __global__ fmha_fwd_gen_reference_kernel(
__syncthreads();
for (int idx_d = threadIdx.x; idx_d < kDim; idx_d += blockDim.x) {
// printf("O[%d,%d,%d] = %f\n", idx_d, idx_h, idx_b, mS[idx_d]);
mO(_0{}, idx_d, make_coord(idx_h, idx_b)) = static_cast<typename TensorO::value_type>(mS[idx_d]);
}
}

View File

@ -0,0 +1,475 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A Blackwell CUTLASS GEMM example for FastFP32 (using BF16 to emulate SGEMM).
This example demonstrates how to run an emulated SGEMM with BF16x9 on an NVIDIA GPU that supports
NVIDIA's Blackwell architecture (SM100a). Using BF16x9 leverages tensor cores, providing much
greater throughput compared to SIMT instructions.
To emulate SGEMM using BF16x9, the A and B matrices are decomposed into three lower-precision elements:
a = a1 + a2 + a3
b = b1 + b2 + b3
One FP32 MAC is equivalent to 9 MACs using BF16:
a * b + c = a1*b1 + a1*b2 + a1*b3 + a2*b1 + a2*b2 + a2*b3 + a3*b1 + a3*b2 + a3*b3 + c
Example 27 demonstrates a similar technique for emulated SGEMM using TF32 with the Ampere architecture.
Usage:
$ ./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm --m=8192 --n=8192 --k=8192
*/
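// A minimal host-side sketch of the decomposition described above (illustrative
// only; the helper to_bf16 is hypothetical and simply truncates a float to
// bfloat16 precision):
//
//   float a1 = to_bf16(a), a2 = to_bf16(a - a1), a3 = to_bf16(a - a1 - a2);
//   float b1 = to_bf16(b), b2 = to_bf16(b - b1), b3 = to_bf16(b - b1 - b2);
//   float emulated = a1*b1 + a1*b2 + a1*b3
//                  + a2*b1 + a2*b2 + a2*b3
//                  + a3*b1 + a3*b2 + a3*b3;   // ~= a * b in fp32
//
// The kernel in this example performs the same expansion with BF16 tensor core
// MMAs instead of scalar arithmetic.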
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = float; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = float; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
// Kernel Perf config
using ClusterTileShape = Shape<_256,_128,_16>; // Cluster-level tile shape
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
using CtaTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // Threadblock-level tile shape
using MmaTileShape = Shape<_256,_128,_16>; // Mma instruction shape
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
CtaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
// Build the mainloop
// Note: Emulated BF16x9 kernels need to manually specify a mainloop schedule and cannot use KernelScheduleAuto
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedFastFP32SmemSm100;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopSchedule
>::CollectiveOp;
// Compose into a kernel
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(8192), n(8192), k(8192),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "78_blackwell_emulated_bf16x9_gemm\n\n"
<< " Blackwell emulated BF16x9 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "78_blackwell_emulated_bf16x9_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = Element(2);
scope_min = Element(0);
} else if (bits_input <= 8) {
scope_max = Element(2);
scope_min = Element(-2);
} else {
scope_max = Element(8);
scope_min = Element(-8);
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
block_A.reset(options.m * options.k);
block_B.reset(options.k * options.n);
block_C.reset(options.m * options.n);
block_D.reset(options.m * options.n);
block_ref_D.reset(options.m * options.n);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{block_A.get(), stride_A, block_B.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
return arguments;
}
bool verify(const Options &options) {
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
{options.m, options.n, options.k},
ElementAccumulator(options.alpha),
ref_A,
ref_B,
ElementAccumulator(options.beta),
ref_C,
ref_D);
// Wait for kernel to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 100a.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,36 @@
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
cutlass_example_add_executable(
78_blackwell_emulated_bf16x9_gemm
78_blackwell_emulated_bf16x9_gemm.cu
)
endif()

View File

@ -146,14 +146,16 @@ foreach(EXAMPLE
64_ada_fp8_gemm_grouped
65_distributed_gemm
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
69_hopper_mixed_dtype_grouped_gemm
70_blackwell_gemm
71_blackwell_gemm_with_collective_builder
72_blackwell_narrow_precision_gemm
73_blackwell_gemm_preferred_cluster
74_blackwell_gemm_streamk
75_blackwell_grouped_gemm
76_blackwell_conv
77_blackwell_fmha
78_blackwell_emulated_bf16x9_gemm
)
add_subdirectory(${EXAMPLE})

View File

@ -246,8 +246,6 @@
Hopper GEMM kernel with Top-K and softmax epilogue fusion.
[//]: #
* [70_blackwell_gemm](70_blackwell_gemm)
Simple dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
@ -280,8 +278,6 @@
Blackwell SM100 FMHA kernel
[//]: #
# CuTe - Programming Examples
Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).