v3.8.0 update (#2082)
* 3.8 update
* fix Markus' name
---------
Co-authored-by: yuzhai <yuzhai@nvidia.com>
@@ -483,18 +483,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
    return 0;
  }

  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
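The recurring change across these examples replaces a hard-coded query of device 0 and a `props.major < 9` lower bound with a `CUDA_CHECK`-wrapped query of the current device and a strict compute-capability 9.0 test (the kernels use Hopper-specific features that do not carry forward to later architectures). A minimal, self-contained sketch of the new guard; the `CUDA_CHECK` macro here is a hypothetical stand-in for the one the examples pull in from `helper.h`:

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical stand-in for the examples' CUDA_CHECK macro from helper.h.
#define CUDA_CHECK(call)                                         \
  do {                                                           \
    cudaError_t status_ = (call);                                \
    if (status_ != cudaSuccess) {                                \
      std::fprintf(stderr, "CUDA error: %s\n",                   \
                   cudaGetErrorString(status_));                 \
      std::exit(EXIT_FAILURE);                                   \
    }                                                            \
  } while (0)

int main() {
  int current_device_id = 0;
  cudaDeviceProp props;
  // Query the device the calling thread is actually using, not device 0.
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  // Require exactly SM90: these kernels rely on Hopper-only features.
  if (props.major != 9 || props.minor != 0) {
    std::fprintf(stderr, "This example requires compute capability 90.\n");
    return 0;
  }
  std::printf("SM90 device detected.\n");
  return 0;
}
```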
@@ -566,17 +566,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -103,11 +103,10 @@
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/mixed_dtype_utils.hpp"

#include "helper.h"
#include "mixed_dtype_utils.hpp"
#include "packed_scale.hpp"
#include "reorder_utils.hpp"

using namespace cute;
@@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using ValueShuffle = Layout<Shape<_2,_4>, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7]
int constexpr NumShuffleAtoms = 1;
using MmaAtomShape = Layout<Shape<_1,Int<NumShuffleAtoms>>>;
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));

using ElementScale = MmaType;
using ElementZero = ElementScale;
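The mixed-dtype helpers these examples previously defined locally (in `mixed_dtype_utils.hpp`, `packed_scale.hpp`, and `reorder_utils.hpp`) now live in the CUTLASS utilities, so call sites gain explicit `cutlass::` / `cute::` qualification. A short sketch of the qualified type setup, assuming `cutlass/util/mixed_dtype_utils.hpp` provides `cutlass::compute_memory_reordering_atom` as this diff indicates, and with `MmaType` chosen for illustration:

```cpp
#include "cutlass/numeric_types.h"
#include "cutlass/util/mixed_dtype_utils.hpp"  // cutlass::compute_memory_reordering_atom

using MmaType = cutlass::half_t;  // assumption: the wide MMA operand type

// Shuffle values within each instruction: order [0,2,4,6,1,3,5,7].
using ValueShuffle = cute::Layout<cute::Shape<cute::_2, cute::_4>,
                                  cute::Stride<cute::_4, cute::_1>>;
using MmaAtomShape = cute::Layout<cute::Shape<cute::_1, cute::Int<1>>>;

// Fully qualified: the atom comes from cutlass::, the tiler from cute::.
using LayoutAtomQuant =
    decltype(cutlass::compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
```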
@@ -438,14 +437,15 @@ void initialize(Options const& options) {
  auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
  stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
  stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
  auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
  auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref);

  dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
  cudaStream_t stream = cudaStreamDefault;
  cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);

  if (options.shuffle) {
    // Repeat the reorder layout atom to tile the whole tensor shape
    layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
    reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
    layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
    cutlass::reorder_tensor(block_B.get(), layout_B, layout_B_reordered);

    print("Quantized tensor layout: ");
    print(layout_B_reordered);
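The local `dequantize_weight` helper (its kernel is deleted later in this diff) is superseded by `cutlass::dequantize`, whose signature adds an explicit CUDA stream as the final parameter. A minimal sketch of the migrated call, assuming the buffers and layouts set up earlier in `initialize()`:

```cpp
// Sketch only: block_B_dq, block_B, layout_B, block_scale, block_zero,
// layout_scale_zero and options.g come from the surrounding example.
cudaStream_t stream = cudaStreamDefault;   // the diff uses the default stream
cutlass::dequantize(block_B_dq.get(),      // dequantized output (MmaType)
                    block_B.get(),         // quantized input (QuantType)
                    layout_B,              // [MN, K, L] operand layout
                    block_scale.get(),     // per-group scales
                    block_zero.get(),      // per-group zero points
                    layout_scale_zero,     // [MN, scale_k, L] scale layout
                    options.g,             // group size along K
                    stream);
```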
@@ -613,17 +613,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -107,11 +107,10 @@
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/mixed_dtype_utils.hpp"

#include "helper.h"
#include "mixed_dtype_utils.hpp"
#include "packed_scale.hpp"
#include "reorder_utils.hpp"

using namespace cute;
@@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
// Define the CuTe layout for the reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));

using ElementScale = MmaType;
using ElementZero = ElementScale; // only used for verification
@@ -349,10 +348,10 @@ void initialize(Options const& options) {

  initialize_tensor(block_A, seed + 2022);
  initialize_quant_tensor(block_B, seed + 2021);
  unify_quant_encoding(block_B, block_B_modified);
  cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
  initialize_tensor(block_C, seed + 2020);
  initialize_scale(block_scale, options);
  initialize_packed_scale(block_scale, block_scale_packed);
  cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size());
  initialize_zero(block_zero, options);

  auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
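Where the old local helpers took `cutlass::DeviceAllocation` references, the new `cutlass::unified_encode_int4b` and `cutlass::pack_scale_fp8` utilities take raw device pointers plus an element count. A sketch of the migrated calls, with names assumed from the surrounding example:

```cpp
// Sketch: block_B / block_B_modified are DeviceAllocation<cutlass::int4b_t>,
// block_scale holds ElementScale values, block_scale_packed the packed
// lookup-table form; all names follow the surrounding example.
cutlass::unified_encode_int4b(block_B.get(),           // int4 input
                              block_B_modified.get(),  // re-encoded output
                              block_B.size());         // element count
cutlass::pack_scale_fp8(block_scale.get(),
                        block_scale_packed.get(),
                        block_scale.size());
```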
@@ -360,12 +359,13 @@ void initialize(Options const& options) {
  stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
  auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);

  dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
  cudaStream_t stream = cudaStreamDefault;
  cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);

  if (options.shuffle) {
    // Repeat the reorder layout atom to tile the whole tensor shape
    layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
    reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
    layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
    cutlass::reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);

    print("Quantized tensor layout: ");
    print(layout_B_reordered);
@@ -518,17 +518,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -100,6 +100,7 @@
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/mixed_dtype_utils.hpp"

#include "helper.h"
#include "mixed_dtype_utils.hpp"
@@ -322,9 +323,10 @@ void initialize(MixedDtypeOptions const& options) {
  auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
  stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
  stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
  auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
  auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref);

  dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
  cudaStream_t stream = cudaStreamDefault;
  cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);
}

/// Populates a Gemm::Arguments structure from the given commandline options
@@ -483,17 +485,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -60,8 +60,8 @@ struct MixedDtypeOptions {

  float alpha = 1.0f;
  float beta = 0.0f;
  int iterations = 1000;
  int warmup = 1000;
  int iterations = 100;
  int warmup = 10;
  int mode = 1;
  int m = 5120, n = 4096, k = 4096;
  int g = 128;
@@ -228,22 +228,18 @@ bool initialize_scale(
  MixedDtypeOptions const& options,
  uint64_t seed = 2023) {

  if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
    // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
    std::vector<Element> stage(block.size(), Element(1.0f));
    block.copy_from_host(stage.data());
  }
  else {
  // If no scales, initialize with 1 so we can use the same kernel to dequantize the data
  float scope_max = 1.0f, scope_min = 1.0f;
  if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
    float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
    const float max_dequant_val = 4.f;
    const float min_dequant_val = 0.5f;

    float scope_max(max_dequant_val / elt_max_f);
    float scope_min(min_dequant_val / elt_max_f);

    cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
    scope_max = max_dequant_val / elt_max_f;
    scope_min = min_dequant_val / elt_max_f;
  }
  cutlass::reference::device::BlockFillRandomUniform(
    block.get(), block.size(), seed, Element(scope_max), Element(scope_min));

  return true;
}
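The refactor collapses the ConvertOnly special case into the general path: a "random" uniform fill over the degenerate range [1, 1] is the same as filling with the constant 1, so a single `BlockFillRandomUniform` call covers both modes. A tiny self-contained sketch of the idea, using hypothetical names rather than the example's types:

```cpp
// Hypothetical enum mirroring the example's MixedDtypeGemmMode.
enum class GemmMode { ConvertOnly, ScaleOnly, ScaleWithZeroPoint };

struct FillRange { float lo, hi; };

// Pick the fill range once; the degenerate [1, 1] range reproduces the old
// "fill with constant 1" behavior for ConvertOnly without a separate branch.
FillRange scale_fill_range(GemmMode mode, float elt_max) {
  FillRange r{1.0f, 1.0f};                  // ConvertOnly: scales are all 1
  if (mode != GemmMode::ConvertOnly) {
    r = {0.5f / elt_max, 4.0f / elt_max};   // keep dequantized values in [0.5, 4]
  }
  return r;
}
```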
@@ -253,139 +249,14 @@ bool initialize_zero(
  MixedDtypeOptions const& options,
  uint64_t seed = 2023) {

  // If no bias, initialize with 0 so we can use the same kernel to dequantize the data
  float scope_max = 0.0f, scope_min = 0.0f;
  if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
    cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
  } else {
    // No bias, so just initialize with 0 so we can use the same kernel to dequantize the data.
    std::vector<Element> stage(block.size(), Element(0.0f));
    block.copy_from_host(stage.data());
    scope_max = 2.0f;
    scope_min = -2.0f;
  }
  cutlass::reference::device::BlockFillRandomUniform(
    block.get(), block.size(), seed, Element(scope_max), Element(scope_min));

  return true;
}

/// Dequantize the weights for verification
template <class QuantizedElement,
          class DequantizedElement,
          class OperandLayout,
          class ElementScale,
          class ElementZero,
          class ScaleBroadCastLayout,
          class ThrLayout>
__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
                                         QuantizedElement const* q_buffer,
                                         OperandLayout const operand_layout,
                                         ElementScale const* scale_buffer,
                                         ElementZero const* zero_buffer,
                                         ScaleBroadCastLayout const broadcasted_scale_layout,
                                         ThrLayout thr_layout) {
  using namespace cute;

  // Represent the full tensors to gmem elements.
  // These are expected to have shape [MN, K, L]
  cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
  auto init_quantized_iterator = [&]() {
    if constexpr (cute::sizeof_bits_v<QuantizedElement> >= 8) {
      return cute::make_gmem_ptr(q_buffer);
    } else {
      return cute::subbyte_iterator<const QuantizedElement>(q_buffer);
    }
  };
  cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout);
  // The scales are expected to have shape [MN, G, L], with a stride that allows broadcasting.
  // It is expected that K % G == 0
  cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
  cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);

  // Assign 1 thread per element in the thread block
  auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{});
  auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)

  // Tile across the block
  auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
  auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
  auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
  auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);

  auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
  auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
  auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
  auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);

  // Make a fragment of registers to hold gmem loads
  cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
  cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
  cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
  cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
  cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
  cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);

  cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
  auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
  auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);

  const auto num_iters = cute::size<3>(tOpDq_gOpDq);

  for (int ii = 0; ii < num_iters; ++ii) {
    const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
    if (thread_offset < cute::size<0>(operand_layout)) {
      cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
      cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
      cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
      cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); });
      cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); });
      cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{});
      cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{});
      cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); });
      cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
    }
  }
}
template <class QuantizedElement,
          class DequantizedElement,
          class OperandLayout,
          class ElementScale,
          class ElementZero,
          class ScaleLayout>
void dequantize_weight(DequantizedElement* dq_buffer,
                       QuantizedElement const* q_buffer,
                       OperandLayout const operand_layout,
                       ElementScale const* scale_buffer,
                       ElementZero const* zero_buffer,
                       ScaleLayout const scale_layout,
                       int const group_size) {

  using namespace cute;

  constexpr int tpb = 128;
  auto thr_layout = make_layout(make_shape(Int<tpb>{}));

  const auto num_rows = get<0>(shape(operand_layout));
  const auto gemm_k = get<1>(shape(operand_layout));   // [MN, K, L]
  const auto batches = get<2>(shape(operand_layout));  // [MN, K, L]
  const auto scale_k = get<1>(shape(scale_layout));    // [MN, Scale_K, L]

  if (num_rows != size<0>(scale_layout)) {
    std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
              << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
              << std::endl;
    exit(-1);
  }

  const auto scale_stride0 = get<0>(stride(scale_layout));
  const auto scale_stride1 = get<1>(stride(scale_layout));
  const auto scale_stride2 = get<2>(stride(scale_layout));

  auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
  auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
  auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);

  const auto blocks_x = gemm_k;
  const auto blocks_y = batches;

  dim3 blocks(blocks_x, blocks_y, 1);
  dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
  CUDA_CHECK(cudaDeviceSynchronize());
}
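The host wrapper's key trick is the broadcast layout: a zero stride on the inner `group_size` mode makes each scale element appear `group_size` times along K, so the kernel can index scales with plain K coordinates. A small host-only CuTe sketch of that idea (made-up sizes, not part of the original code):

```cpp
#include <iostream>
#include <cute/layout.hpp>

int main() {
  using namespace cute;
  int group_size = 3;
  // Broadcast layout: shape (4, (3, 2)), stride (1, (0, 4)). The zero stride
  // on the inner mode repeats each scale group_size times along K, so the
  // coordinate (row, (k_in_group, group)) maps to that group's single scale.
  auto bcast = make_layout(make_shape(4, make_shape(group_size, 2)),
                           make_stride(1, make_stride(0, 4)));
  std::cout << bcast(0, make_coord(0, 0)) << "\n";  // 0
  std::cout << bcast(0, make_coord(2, 0)) << "\n";  // 0: same scale, broadcast
  std::cout << bcast(0, make_coord(0, 1)) << "\n";  // 4: next group's scale
  return 0;
}
```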
@@ -1,211 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include <cstdint>

#include "cutlass/util/device_memory.h"
#include "cutlass/integer_subbyte.h"
#include "cutlass/float8.h"
#include "cutlass/util/reference/device/tensor_fill.h"

#include "cute/tensor.hpp"
#include "cute/util/type_traits.hpp"

namespace cutlass
{
template<typename T>
class packed_scale_t {
public:
  static_assert(cute::is_same_v<T, cutlass::int8_t> ||
                cute::is_same_v<T, cutlass::uint8_t> ||
                cute::is_same_v<T, cutlass::float_e4m3_t> ||
                cute::is_same_v<T, cutlass::float_e5m2_t>,
                "only 8 bit arithmetic types are supported.");
  CUTLASS_HOST_DEVICE
  explicit packed_scale_t(T val) {
    if constexpr (!cute::is_unsigned_v<T>) {
      // Only pack negative values. The positive values are generated in flight in the mainloop.
      storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
      storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
    }
    else {
      storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
      storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
    }
  }
  CUTLASS_HOST_DEVICE
  packed_scale_t() = default;
  CUTLASS_HOST_DEVICE
  explicit operator float() const {
    return float(get());
  }
  CUTLASS_HOST_DEVICE
  bool operator==(packed_scale_t const& rhs) const {
    return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
  }
  CUTLASS_HOST_DEVICE
  bool operator!=(packed_scale_t const& rhs) const {
    return !(*this == rhs);
  }
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() + rhs.get());
  }
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() - rhs.get());
  }
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() * rhs.get());
  }
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() / rhs.get());
  }

private:
  using Storage = uint32_t;
  using Stage = uint8_t;

  Storage storage[2] {};

  CUTLASS_HOST_DEVICE
  static Storage pack4(T c1, T c2, T c3, T c4) {
    Storage result = 0;
    result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
    result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
    result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
    result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
    return result;
  }
  CUTLASS_HOST_DEVICE
  T get() const {
    auto stage = static_cast<Stage>(storage[0] >> 8);
#if defined(__CUDA_ARCH__)
    return reinterpret_cast<T const&>(stage);
#else
    T tmp;
    std::memcpy(&tmp, &stage, sizeof(Stage));
    return tmp;
#endif
  }
  CUTLASS_HOST_DEVICE
  T get(int idx) const {
    Stage stage;
    if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
    else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
#if defined(__CUDA_ARCH__)
    return reinterpret_cast<T const&>(stage);
#else
    T tmp;
    std::memcpy(&tmp, &stage, sizeof(Stage));
    return tmp;
#endif
  }
};
}

/// Helpers to initialize scale lookup table

// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
// Here the encodings of positive values and negative values are unified (except for the sign bit).
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
bool unify_quant_encoding(
    cutlass::DeviceAllocation<cutlass::int4b_t> const& block_in,
    cutlass::DeviceAllocation<cutlass::int4b_t>& block_out) {

  using StorageType = cutlass::int4b_t::Storage;

  if (block_in.size() != block_out.size()) {
    std::cerr << "block_in and block_out must have same size.\n";
    return false;
  }
  constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
  std::vector<StorageType> data(block_in.size() / pack);
  cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack);

  for (auto&& d : data) {
    StorageType out = 0;
    StorageType mask = 0x0f;
    for (int i = 0; i < pack; ++i) {
      cutlass::int4b_t curr;
      curr.storage = (d >> (i * 4)) & 0x0f;
      switch (curr) {
        case 1: curr.storage = StorageType(0b0111); break; // 2's complement
        case 2: curr.storage = StorageType(0b0110); break; // 2's complement
        case 3: curr.storage = StorageType(0b0101); break; // 2's complement
        case 4: curr.storage = StorageType(0b0100); break; // 2's complement
        case 5: curr.storage = StorageType(0b0011); break; // 2's complement
        case 6: curr.storage = StorageType(0b0010); break; // 2's complement
        case 7: curr.storage = StorageType(0b0001); break; // 2's complement
        default: break;
      }
      out |= (curr.storage << (4 * i)) & mask;
      mask <<= 4;
    }
    d = out;
  }

  cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack);
  return true;
}
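The table in `unify_quant_encoding` maps each positive int4 magnitude k to the bit pattern of -k with the sign bit cleared, so a single PRMT lookup keyed on the low three bits serves both signs. A tiny standalone sketch of just the mapping (plain C++, no CUTLASS types):

```cpp
#include <cassert>
#include <cstdint>

// Unified 4-bit encoding: a positive k (1..7) is stored as the two's-complement
// pattern of -k with the sign bit (bit 3) cleared; other values pass through.
uint8_t unify_nibble(uint8_t nibble) {
  uint8_t v = nibble & 0x0f;
  if (v >= 1 && v <= 7) {
    return uint8_t((0x10 - v) & 0x07);  // e.g. 1 -> 0b0111, 7 -> 0b0001
  }
  return v;
}

int main() {
  assert(unify_nibble(1) == 0b0111);       // same as -1 (0b1111) minus sign bit
  assert(unify_nibble(2) == 0b0110);       // matches the switch table above
  assert(unify_nibble(7) == 0b0001);       // same as -7 (0b1001) minus sign bit
  assert(unify_nibble(0b1111) == 0b1111);  // negative values are unchanged
  assert(unify_nibble(0) == 0);
  return 0;
}
```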
template <class ElementScale>
bool initialize_packed_scale(
    cutlass::DeviceAllocation<ElementScale> const& block_in,
    cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8>>& block_out) {

  std::vector<ElementScale> data_in(block_in.size());
  std::vector<cutlass::Array<ElementScale, 8>> data_out(block_in.size());
  try {
    block_in.copy_to_host(data_in.data());
  } catch (cutlass::cuda_exception const& e) {
    std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
    return false;
  }
  for (size_t i = 0; i < block_in.size(); ++i) {
    cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
    data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
  }
  try {
    block_out.copy_from_host(data_out.data());
  } catch (cutlass::cuda_exception const& e) {
    std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
    return false;
  }
  return true;
}
@@ -1,162 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/mma_sm90.hpp"

#include "cutlass/util/device_memory.h"

// Given a type of MMA instruction, compute a memory reordering atom that places all values
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
// In addition, we can reorder the values across several MMA instructions to get even wider
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
// more optimal conversion instruction sequences (ValLayout parameter).
template<class ElementMma,
         class AtomLayout = cute::Layout<cute::_1>,
         class ValLayout = cute::Layout<cute::_1>>
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
{
  using namespace cute;

  static_assert(is_static_v<ValLayout>, "ValLayout must be static");
  static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");

  // 1. Choose an MMA atom to access TV layout and MN shape
  // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
  using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
  using MmaTraits = MMA_Traits<MmaAtom>;
  auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
  auto tv_layout_mma = typename MmaTraits::ALayout{};
  static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");

  // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
  // Note: this assumes A is partitioned between warps along M mode
  auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
  auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
  auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
  auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);

  // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
  auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);

  // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
  auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
  auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
  auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
  auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);

  return layout_atom;
}
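Taken together, the four steps derive an (m,k) -> offset atom from the MMA instruction's thread-value layout. A short sketch of how the examples consume it once it moves into the CUTLASS utilities (the header path follows this diff; `MmaType` is an assumption for illustration):

```cpp
#include "cute/tensor.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/mixed_dtype_utils.hpp"  // per this diff: cutlass::compute_memory_reordering_atom

using MmaType = cutlass::half_t;  // wide operand type (assumption)

// Atom that groups each thread's values contiguously in memory...
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType>());

// ...then, at runtime, tiled over the full (N, K, L) extent of tensor B:
//   layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
```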
template<class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
__global__ void reorder_tensor_kernel(
    cute::Tensor<EngineSrc, LayoutSrc> S,
    cute::Tensor<EngineDst, LayoutDst> D,
    TiledCopy tiled_copy)
{
  using namespace cute;

  using T = typename EngineDst::value_type;

  Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
  Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));

  auto thread_copy = tiled_copy.get_slice(threadIdx.x);
  Tensor tS = thread_copy.partition_S(gS);
  Tensor tD = thread_copy.partition_D(gD);

  copy(tiled_copy, tS, tD);
}

template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
void reorder_tensor(
    cute::Tensor<EngineSrc, LayoutSrc> S,
    cute::Tensor<EngineDst, LayoutDst> D)
{
  using namespace cute;

  using T = typename EngineDst::value_type;
  static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");

  // Construct a value layout that assigns at least 8 bits of contiguous elements in the destination tensor to a thread.
  // This avoids a race condition when writing out subbyte types (e.g. int4b_t).
  auto has_major_mode = [](auto s) {
    return any_of(s, [](auto a){ return is_constant<1, decltype(a)>{}; });
  };
  static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
                "Could not find stride-1 mode in destination layout");
  constexpr int N = shape_div(Int<8>{}, sizeof_bits<T>{});
  auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
      make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
      make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));

  // Make a tiled copy with a simple row-major thread order and the above value layout
  int constexpr NumThreads = 128;
  auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
  auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);

  // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
  using TileShape = Shape<_16>;
  auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
  dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};

  reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
  CUDA_CHECK(cudaDeviceSynchronize());
}

// Out-of-place version
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
    T const* src,
    LayoutSrc const& layout_src,
    T * dst,
    LayoutDst const& layout_dst)
{
  using namespace cute;
  reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
                 make_tensor(make_gmem_ptr<T>(dst), layout_dst));
}

// In-place version
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
    T * data,
    LayoutSrc const& layout_src,
    LayoutDst const& layout_dst)
{
  using namespace cute;
  cutlass::DeviceAllocation<T> temp(size(layout_src));
  reorder_tensor(data, layout_src, temp.get(), layout_dst);
  cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
}
@@ -513,17 +513,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -731,17 +731,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -768,16 +768,26 @@ int main(int argc, char const** argv) {
    return -1;
  }

  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4) ||
      (props.major != 8 && props.minor != 9)) {
  bool satisfied;
  if (props.major < 10) {
    // Pre-Blackwell
    satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4);
    satisfied &= (props.major > 8) || (props.major == 8 && props.minor == 9);
  }
  else {
    satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
  }

  if (!satisfied) {
    //
    // This example requires an NVIDIA Ada-architecture GPU.
    // This example requires an NVIDIA GPU with compute capability 8.9 or greater.
    //

    std::cout
      << "CUTLASS's FP8 SM89 example requires a GPU of NVIDIA's Ada architecture "
      << "and CUDA toolkit version 12.4 or later.\n";
      << "CUTLASS's FP8 SM89 example requires an NVIDIA GPU with compute capability 8.9 or greater "
      << "and CUDA toolkit version 12.4 or later"
      << " (12.8 or later needed for SM100+)"
      << std::endl;

    return 0;
  }
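The old guard rejected Blackwell-class devices outright, and it had a logic slip: `props.major != 8 && props.minor != 9` only fires when *both* numbers differ from 8.9, so a device like SM 8.6 slipped through. The new version keys the minimum toolkit on the architecture generation. A compact, dependency-free sketch of the same decision, with the toolkit version passed as plain integers for testability:

```cpp
// Hypothetical helper mirroring the example's logic: returns true when the
// (toolkit, device) pair can run the FP8 SM89 example.
bool fp8_sm89_supported(int cuda_major, int cuda_minor, int sm_major, int sm_minor) {
  if (sm_major < 10) {
    // Pre-Blackwell: need CUDA 12.4+ and compute capability 8.9 or greater.
    bool toolkit_ok = (cuda_major > 12) || (cuda_major == 12 && cuda_minor >= 4);
    bool device_ok  = (sm_major > 8) || (sm_major == 8 && sm_minor == 9);
    return toolkit_ok && device_ok;
  }
  // SM100+ (Blackwell): needs CUDA 12.8 or later.
  return (cuda_major > 12) || (cuda_major == 12 && cuda_minor >= 8);
}
```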
@@ -504,17 +504,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -570,18 +570,13 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
    return 0;
  }

  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
@@ -469,18 +469,12 @@ int main(int argc, char const **args) {
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (props.major < 9) {
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  else if (props.major != 9 || props.minor != 0) {
    std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
  // Parse options
@@ -31,28 +31,20 @@

/*! \file
    \brief Grouped scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture

    This example demonstrates a grouped scaled FP8 GEMM using the new CUTLASS 3.0
    APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:

    1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
    which are more efficient than the Ampere tensor core instructions.

    2. NVIDIA Hopper architecture includes a new Tensor Memory Accelerator (TMA) unit to transfer large
    blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
    copies between thread blocks in a cluster.

    3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).

    4. This example shows all important fusions used by FP8 GEMM kernels, i.e., a grouped scale factor along M for
    the A tensor, a blocked scale factor along K for the A tensor, a blocked scale factor for the B tensor, and the abs_max value of the D tensor.

    5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
    CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning them we can
    improve performance.

    Examples:

      $ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \
        --m=2816 --n=3072 --k=16384 \
        --save_aux=false --save_amax=false \
@@ -34,4 +34,4 @@ cutlass_example_add_executable(
cutlass_example_add_executable(
  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
)
)
@@ -191,7 +191,7 @@ void gett_mainloop(

  static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
  static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");

  using cute::raw_pointer_cast;

  using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
@@ -0,0 +1,818 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
      NOTE: Write docu
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <numeric>
#include <typeinfo>
#include <float.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"

#include "helper.h"
#include "grouped_mixed_dtype_utils.hpp"

using namespace cute;

using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////

// A matrix configuration
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor;      // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = QuantType;                     // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor;   // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// This example manually swaps and transposes, so keep the transpose of the input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;

// Need to pass a pointer type to make the 3rd dimension of Stride be _0
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;

// Define the CuTe layout for the reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
// using ValueShuffle = Layout<_1>; // no value reordering
using ValueShuffle = Layout<Shape<_2,_4>, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7]
int constexpr NumShuffleAtoms = 1;
using MmaAtomShape = Layout<Shape<_1,Int<NumShuffleAtoms>>>;
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,Int<1>>, StrideB>{}));

using ElementZero = cutlass::bfloat16_t;
using ElementScale = cutlass::bfloat16_t;
using LayoutScale = cutlass::layout::RowMajor;

// C/D matrix configuration
using ElementC = cutlass::half_t;               // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor;      // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float;               // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90;            // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>;  // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>;           // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;  // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
    ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
    EpilogueSchedule
  >::CollectiveOp;

using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementB, LayoutB_Transpose *, AlignmentB,
    ElementA, LayoutA_Transpose *, AlignmentA,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    CollectiveMainloopConvertOnly,
    CollectiveEpilogue
>;

using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;

using CollectiveMainloopConvertOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementB, LayoutB_Reordered *, AlignmentB,
    ElementA, LayoutA_Transpose *, AlignmentA,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

using GemmKernelConvertOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    CollectiveMainloopConvertOnlyShuffled,
    CollectiveEpilogue
>;

using GemmConvertOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnlyShuffled>;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The scale information must get paired with the operand that will be scaled. In this example, B is scaled, so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    cute::tuple<ElementB, ElementScale>, LayoutB_Transpose *, AlignmentB,
    ElementA, LayoutA_Transpose *, AlignmentA,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    CollectiveMainloopScaleOnly,
    CollectiveEpilogue
>;

using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;

using CollectiveMainloopScaleOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    cute::tuple<ElementB, ElementScale>, LayoutB_Reordered *, AlignmentB,
    ElementA, LayoutA_Transpose *, AlignmentA,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

using GemmKernelScaleOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    CollectiveMainloopScaleOnlyShuffled,
    CollectiveEpilogue
>;

using GemmScaleOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnlyShuffled>;
using StrideC = typename GemmKernelConvertOnly::InternalStrideC;
|
||||
using StrideD = typename GemmKernelConvertOnly::InternalStrideD;
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_B_dq;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_scale;
|
||||
std::vector<int64_t> offset_zero;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<StrideC_ref> stride_C_host_ref;
|
||||
std::vector<StrideD_ref> stride_D_host_ref;
|
||||
std::vector<StrideS> stride_S_host;
|
||||
std::vector<StrideS_ref> stride_S_host_ref;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<MmaType> block_A;
|
||||
cutlass::DeviceAllocation<QuantType> block_B;
|
||||
cutlass::DeviceAllocation<MmaType> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_A;
|
||||
cutlass::DeviceAllocation<const QuantType *> ptr_B;
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
|
||||
cutlass::DeviceAllocation<const ElementScale *> ptr_scale;
|
||||
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<LayoutB_Reordered> layout_B_reordered;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
|
||||
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
|
||||
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
|
||||
cutlass::DeviceAllocation<StrideS> stride_S;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options : GroupedMixedDtypeOptions<QuantType> {
|
||||
using Base = GroupedMixedDtypeOptions<QuantType>;
|
||||
|
||||
bool shuffle = true;
|
||||
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("shuffle", shuffle);
|
||||
|
||||
this->Base::parse(argc, args);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "69_hopper_int4_bf16_grouped_gemm\n\n"
|
||||
<< " Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --warmup=<int> Number of warmup iterations to perform\n\n"
|
||||
<< " --shuffle=<boolean> Enable the offline layout swizzling.\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size.\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "69_hopper_int4_bf16_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(Options const& options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_B_dq = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_scale = 0;
|
||||
int64_t total_elements_zero = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
const int scale_k = 1;
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
offset_B_dq.push_back(total_elements_B_dq);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_scale.push_back(total_elements_scale);
|
||||
offset_zero.push_back(total_elements_zero);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N ;
|
||||
int64_t elements_B_dq = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_scale = scale_k * N;
|
||||
int64_t elements_zero = scale_k * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_B_dq += elements_B_dq;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_scale += elements_scale;
|
||||
total_elements_zero += elements_zero;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
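    // C/D strides are built with the (N, M) extents swapped because the kernel runs with A and B
    // exchanged (SwapAB); the *_ref strides keep the original (M, N) order for the reference GEMM
    // used in verify().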
    stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
    stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
    stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
    stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
    stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
    stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
  }

  block_A.reset(total_elements_A);
  block_B.reset(total_elements_B);
  block_B_dq.reset(total_elements_B_dq);
  block_C.reset(total_elements_C);
  block_D.reset(total_elements_D);
  block_ref_D.reset(total_elements_D);
  block_scale.reset(total_elements_scale);
  block_zero.reset(total_elements_zero);

  block_alpha.reset(options.groups);
  block_beta.reset(options.groups);
}

/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options &options) {

  uint64_t seed = 2020;

  problem_sizes.reset(options.groups);
  problem_sizes.copy_from_host(options.problem_sizes_host.data());

  //
  // Assign pointers
  //

  std::vector<MmaType *> ptr_A_host(options.groups);
  std::vector<QuantType *> ptr_B_host(options.groups);
  std::vector<MmaType *> ptr_B_dq_host(options.groups);
  std::vector<ElementC *> ptr_C_host(options.groups);
  std::vector<ElementC *> ptr_D_host(options.groups);
  std::vector<ElementScale *> ptr_scale_host(options.groups);
  std::vector<ElementZero *> ptr_zero_host(options.groups);
  std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
  std::vector<ElementAccumulator *> ptr_beta_host(options.groups);

  for (int32_t i = 0; i < options.groups; ++i) {
    ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
    ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
    ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
    ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
    ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
    ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i);
    ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
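    // alpha/beta equal to FLT_MAX is the sentinel for "not set on the command line"; in that case
    // random per-group values are generated so the per-group scaling path is exercised.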
    alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
    beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
    ptr_alpha_host.at(i) = block_alpha.get() + i;
    ptr_beta_host.at(i) = block_beta.get() + i;
  }

  ptr_A.reset(options.groups);
  ptr_A.copy_from_host(ptr_A_host.data());

  ptr_B.reset(options.groups);
  ptr_B.copy_from_host(ptr_B_host.data());

  ptr_B_dq.reset(options.groups);
  ptr_B_dq.copy_from_host(ptr_B_dq_host.data());

  ptr_C.reset(options.groups);
  ptr_C.copy_from_host(ptr_C_host.data());

  ptr_D.reset(options.groups);
  ptr_D.copy_from_host(ptr_D_host.data());

  ptr_scale.reset(options.groups);
  ptr_scale.copy_from_host(ptr_scale_host.data());

  ptr_zero.reset(options.groups);
  ptr_zero.copy_from_host(ptr_zero_host.data());

  stride_A.reset(options.groups);
  stride_A.copy_from_host(stride_A_host.data());

  stride_B.reset(options.groups);
  stride_B.copy_from_host(stride_B_host.data());

  stride_C.reset(options.groups);
  stride_C.copy_from_host(stride_C_host.data());

  stride_D.reset(options.groups);
  stride_D.copy_from_host(stride_D_host.data());

  stride_C_ref.reset(options.groups);
  stride_C_ref.copy_from_host(stride_C_host_ref.data());

  stride_D_ref.reset(options.groups);
  stride_D_ref.copy_from_host(stride_D_host_ref.data());

  stride_S_ref.reset(options.groups);
  stride_S_ref.copy_from_host(stride_S_host_ref.data());

  stride_S.reset(options.groups);
  stride_S.copy_from_host(stride_S_host.data());

  alpha_device.reset(options.groups);
  alpha_device.copy_from_host(ptr_alpha_host.data());
  beta_device.reset(options.groups);
  beta_device.copy_from_host(ptr_beta_host.data());

  initialize_tensor(block_A, seed + 2023);
  initialize_quant_tensor(block_B, seed + 2022);
  initialize_tensor(block_C, seed + 2021);
  initialize_scale(block_scale, options);
  initialize_zero(block_zero, options);
  block_alpha.copy_from_host(alpha_host.data());
  block_beta.copy_from_host(beta_host.data());


  for (int32_t i = 0; i < options.groups; ++i) {
    const int scale_k = 1;
    auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
    auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{});
    auto layout_B = make_layout(shape_B, stride_B_host.at(i));
    auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i));
    cudaStream_t stream = cudaStreamDefault;
    cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream);
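    // block_B_dq now holds a dequantized (MmaType) copy of the quantized weights; verify() feeds
    // it to a plain reference GEMM to check the kernel's in-flight conversion and scaling.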
  }

  problem_sizes.reset(options.groups);

  if (options.shuffle) {
    std::vector<LayoutB_Reordered> layout_B_reordered_host(options.groups);
    for (int32_t i = 0; i < options.groups; ++i) {
      auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
      auto layout_B = make_layout(shape_B, stride_B_host.at(i));
      // Repeat the reorder layout atom to tile the whole tensor shape
      layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B);
      cutlass::reorder_tensor(block_B.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]);
      if (i == 0) {
        print("Quantized tensor layout: ");
        print(layout_B_reordered_host[0]);
        print("\n");
      }
    }
    layout_B_reordered.reset(options.groups);
    layout_B_reordered.copy_from_host(layout_B_reordered_host.data());
  }

  // Reverse MN -> NM for SwapAB
  for (int32_t i = 0; i < options.groups; ++i) {
    auto [M, N, K] = options.problem_sizes_host[i];
    options.problem_sizes_host[i] = make_tuple(N, M, K);
  }
  problem_sizes.copy_from_host(options.problem_sizes_host.data());
}

/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Gemm>
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
{
  using Args = typename Gemm::Arguments;
  auto&& dB = [&]() {
    // NOTE: add GemmScaleWithZeroPointShuffled
    if constexpr (cute::is_same_v<Gemm, GemmConvertOnlyShuffled> ||
                  cute::is_same_v<Gemm, GemmScaleOnlyShuffled>) {
      // offline swizzling is enabled.
      return layout_B_reordered.get();
    }
    else {
      return stride_B.get();
    }
  }();
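  // dB is either a per-group LayoutB_Reordered (for the shuffled kernels, which need the full
  // reordered layout of B rather than just a stride) or a plain per-group StrideB otherwise.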
  cutlass::KernelHardwareInfo hw_info;
  // Change device_id to another value if you are running on a machine with multiple GPUs and wish
  // to use a GPU other than that with device ID 0.
  hw_info.device_id = 0;
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  Args arguments;
  decltype(arguments.epilogue.thread) fusion_args;

  if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
    // If both alpha/beta are provided (via cmd line args) and are scalar, the same alpha/beta applies to all groups.
    fusion_args.alpha = options.alpha;
    fusion_args.beta = options.beta;
    fusion_args.alpha_ptr = nullptr;
    fusion_args.beta_ptr = nullptr;
    fusion_args.alpha_ptr_array = nullptr;
    fusion_args.beta_ptr_array = nullptr;
    // Single alpha and beta for all groups
    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
  }
  else {
    // If pointers to alpha/beta are provided, alpha/beta can differ between groups.
    fusion_args.alpha = 0;
    fusion_args.beta = 0;
    fusion_args.alpha_ptr = nullptr;
    fusion_args.beta_ptr = nullptr;
    fusion_args.alpha_ptr_array = alpha_device.get();
    fusion_args.beta_ptr_array = beta_device.get();
    // One alpha and beta per group
    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
  }

  if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) {
    arguments = Args {
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {options.groups, problem_sizes.get(), nullptr},
      {ptr_B.get(), dB, ptr_A.get(), stride_A.get()},
      {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
      hw_info
    };
  }
  else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) {
    arguments = Args {
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {options.groups, problem_sizes.get(), nullptr},
      {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
      {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
      hw_info
    };
  }
  else {
    std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
    exit(-1);
  }
  return arguments;
}

bool verify(Options const& options) {
  bool passed = true;

  constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
  using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
  using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
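  // The reference kernel mirrors the mixed-dtype kernel's accumulation behavior: FP8 inputs get a
  // fast-accumulation schedule (Pingpong for 64-row tiles, Cooperative otherwise), while other MMA
  // types use the builder's automatic schedule.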


  using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      MmaType, LayoutA, AlignmentA,
      MmaType, LayoutB, AlignmentB,
      ElementAccumulator,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      ScheduleRef
  >::CollectiveOp;

  using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutC, AlignmentC,
      ElementD, LayoutD, AlignmentD,
      cutlass::epilogue::NoSmemWarpSpecialized
  >::CollectiveOp;

  using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int>, // Indicates ProblemShape
      CollectiveMainloopRef,
      CollectiveEpilogueRef
  >;

  using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
  using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
  using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
  using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
  using StrideD_verif = typename GemmRef::GemmKernel::StrideD;

  const ElementD epsilon(1e-2f);
  const ElementD non_zero_floor(1e-4f);
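  // epsilon is the relative tolerance for the block comparison below; non_zero_floor keeps the
  // relative-error computation away from division by near-zero reference values.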

  for (int32_t i = 0; i < options.groups; ++i) {
    auto problem = options.problem_sizes_host.at(i);
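    // problem_sizes_host was rewritten to (N, M, K) in initialize() for SwapAB, so index 0 holds N
    // and index 1 holds M here.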
    auto N = get<0>(problem);
    auto M = get<1>(problem);
    auto K = get<2>(problem);
    if (M == 0) {
      continue;
    }
    else {
      StrideA_verif stride_A_verif;
      StrideB_verif stride_B_verif;

      stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
      stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));

      //
      // Compute reference output
      //

      typename GemmRef::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm,
        {M, N, K},
        {block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
        {{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
      };

      // Run the gemm where the scaling is performed outside of the kernel.
      GemmRef gemm_ref;
      size_t workspace_size = GemmRef::get_workspace_size(arguments);
      cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
      CUTLASS_CHECK(gemm_ref.can_implement(arguments));
      CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
      CUTLASS_CHECK(gemm_ref.run());

      // Wait for kernel to finish
      CUDA_CHECK(cudaDeviceSynchronize());

      passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
      std::cout << "Group: " << i << " Status: " << passed << std::endl;
    }
  }
  return passed;
}

/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options, bool host_problem_shapes_available = true)
{
  allocate(options);
  initialize(options);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm;

  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
  auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check if the problem size is supported or not
  CUTLASS_CHECK(gemm.can_implement(arguments));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));

  // Correctness / Warmup iteration
  CUTLASS_CHECK(gemm.run());

  std::cout << "We passed all checks\n";
  // Check if output from CUTLASS kernel and reference kernel are equal or not
  MixedDtypeResult result;
  result.passed = verify(options);
  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
  grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
  if (!result.passed) {
    exit(-1);
  }

  return 0;
}

#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  // CUTLASS must be compiled with the CUDA 12.3 Toolkit or newer to run this example
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
    std::cerr << "This example requires CUDA 12.3 or newer.\n";
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }


  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  //
  // Evaluate CUTLASS kernels
  //

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
  if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
    std::cout << "Running in no scale mode." << std::endl;
    if (options.shuffle) {
      std::cout << "Offline shuffle enabled." << std::endl;
      run<GemmConvertOnlyShuffled>(options, false);
    } else {
      std::cout << "Offline shuffle disabled." << std::endl;
      run<GemmConvertOnly>(options, false);
    }
  }
  else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
    std::cout << "Running in per-column scale mode." << std::endl;
    if (options.shuffle) {
      std::cout << "Offline shuffle enabled." << std::endl;
      run<GemmScaleOnlyShuffled>(options, false);
    } else {
      std::cout << "Offline shuffle disabled." << std::endl;
      run<GemmScaleOnly>(options, false);
    }
  }
#endif

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,753 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/


/*! \file
    \brief Hopper Mixed Dtype Grouped GEMM example mixing an INT4 quantized B operand with an
      FP8 A operand. The INT4 values are re-encoded and the FP8 scales packed so the mainloop
      can dequantize efficiently, with optional offline layout shuffling of B.
*/

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <numeric>
#include <typeinfo>
#include <float.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"

#include "helper.h"
#include "grouped_mixed_dtype_utils.hpp"

using namespace cute;

using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
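// With an 8-bit MmaType such as FP8, TileShapeK = 128 * 8 / 8 = 128: the K-tile always spans
// 128 bytes of the MMA operand regardless of its element width.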

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////

// A matrix configuration
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor;                              // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = QuantType;                                             // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor;                           // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// This example manually swaps and transposes, so keep the transpose of the input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;

// Need to pass a pointer type to make the 3rd dimension of Stride be _0
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;

// Define the CuTe layout for the reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,Int<1>>, StrideB>{}));

using ElementZero = cutlass::float_e4m3_t;
using ElementScale = cutlass::float_e4m3_t;
using LayoutScale = cutlass::layout::RowMajor;

// C/D matrix configuration
using ElementC = cutlass::half_t;                                       // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor;                              // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float;                                       // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90;                                    // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;                   // Operator class tag
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>;                // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>;                                   // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto;       // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
    ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
    EpilogueSchedule
  >::CollectiveOp;

// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The scale information must be paired with the operand that will be scaled. In this example, B is scaled, so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Transpose *, AlignmentB,
    ElementA, LayoutA_Transpose *, AlignmentA,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    CollectiveMainloopScaleOnly,
    CollectiveEpilogue
>;

using CollectiveMainloopShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Reordered *, AlignmentB,
    ElementA, LayoutA_Transpose *, AlignmentA,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    CollectiveMainloopShuffled,
    CollectiveEpilogue
>;

using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using GemmShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;

using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;

// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_B_dq;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_scale;
std::vector<int64_t> offset_zero;

std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<StrideC_ref> stride_C_host_ref;
std::vector<StrideD_ref> stride_D_host_ref;
std::vector<StrideS> stride_S_host;
std::vector<StrideS_ref> stride_S_host_ref;

std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;

uint64_t seed = 2020;

// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;

cutlass::DeviceAllocation<MmaType> block_A;
cutlass::DeviceAllocation<QuantType> block_B;
cutlass::DeviceAllocation<ElementB> block_B_modified;
cutlass::DeviceAllocation<MmaType> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8>> block_scale_packed;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;

cutlass::DeviceAllocation<const MmaType *> ptr_A;
cutlass::DeviceAllocation<const QuantType *> ptr_B;
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
cutlass::DeviceAllocation<const cutlass::Array<ElementScale, 8> *> ptr_scale_packed;
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput *> ptr_D;

cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<LayoutB_Reordered> layout_B_reordered;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
cutlass::DeviceAllocation<StrideS> stride_S;

// Note, this is an array of pointers to the alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;

#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options : GroupedMixedDtypeOptions<QuantType> {
  using Base = GroupedMixedDtypeOptions<QuantType>;

  bool shuffle = true;

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);
    cmd.get_cmd_line_argument("shuffle", shuffle);

    this->Base::parse(argc, args);

    mode = 1; // Override the mode value: this example always runs in scale-only mode.
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "69_hopper_int4_fp8_grouped_gemm\n\n"
        << "  Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n"
        << "Options:\n\n"
        << "  --help                      If specified, displays this usage statement\n\n"
        << "  --m=<int>                   Sets the M extent of the GEMM for all groups\n"
        << "  --n=<int>                   Sets the N extent of the GEMM for all groups\n"
        << "  --k=<int>                   Sets the K extent of the GEMM for all groups\n"
        << "  --groups=<int>              Sets the number of individual GEMM problems for Grouped GEMM\n"
        << "  --c=<int>                   The size of each chunk for the scales and zeros. To broadcast a vector of scales or zeros, set the chunk size to K.\n"
        << "  --alpha=<f32>               Epilogue scalar alpha\n"
        << "  --beta=<f32>                Epilogue scalar beta\n\n"
        << "  --iterations=<int>          Number of profiling iterations to perform\n\n"
        << "  --warmup=<int>              Number of warmup iterations to perform\n\n"
        << "  --shuffle=<boolean>         Enable the offline layout swizzling.\n\n"
        << "  --benchmark=<str>           Executes a benchmark problem size.\n";

    out
        << "\n\nExamples:\n\n"
        << "$ " << "69_hopper_int4_fp8_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n";

    return out;
  }
};

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////

// In the mainloop, PRMT selects one byte from only eight bytes, so the sign bit is handled in an extra PRMT.
// Here the encodings of positive values and negative values are unified (except for the sign bit).
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).

/// Allocates device-side data
void allocate(Options const& options) {
  int64_t total_elements_A = 0;
  int64_t total_elements_B = 0;
  int64_t total_elements_B_dq = 0;
  int64_t total_elements_C = 0;
  int64_t total_elements_D = 0;
  int64_t total_elements_scale = 0;
  int64_t total_elements_zero = 0;

  for (int32_t i = 0; i < options.groups; ++i) {

    auto problem = options.problem_sizes_host.at(i);
    auto M = get<0>(problem);
    auto N = get<1>(problem);
    auto K = get<2>(problem);

    const int scale_k = 1;

    offset_A.push_back(total_elements_A);
    offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
    offset_B_dq.push_back(total_elements_B_dq);
    offset_C.push_back(total_elements_C);
    offset_D.push_back(total_elements_D);
    offset_scale.push_back(total_elements_scale);
    offset_zero.push_back(total_elements_zero);

    int64_t elements_A = M * K;
    int64_t elements_B = K * N;
    int64_t elements_B_dq = K * N;
    int64_t elements_C = M * N;
    int64_t elements_D = M * N;
    int64_t elements_scale = scale_k * N;
    int64_t elements_zero = scale_k * N;

    total_elements_A += elements_A;
    total_elements_B += elements_B;
    total_elements_B_dq += elements_B_dq;
    total_elements_C += elements_C;
    total_elements_D += elements_D;
    total_elements_scale += elements_scale;
    total_elements_zero += elements_zero;

    stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
    stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
    stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
    stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
    stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
    stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
    stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
    stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
  }

  block_A.reset(total_elements_A);
  block_B.reset(total_elements_B);
  block_B_modified.reset(total_elements_B);
  block_B_dq.reset(total_elements_B_dq);
  block_C.reset(total_elements_C);
  block_D.reset(total_elements_D);
  block_ref_D.reset(total_elements_D);
  block_scale.reset(total_elements_scale);
  block_scale_packed.reset(total_elements_scale);
  block_zero.reset(total_elements_zero);

  block_alpha.reset(options.groups);
  block_beta.reset(options.groups);
}

/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options& options) {

  uint64_t seed = 2020;

  problem_sizes.reset(options.groups);
  problem_sizes.copy_from_host(options.problem_sizes_host.data());

  //
  // Assign pointers
  //

  std::vector<MmaType *> ptr_A_host(options.groups);
  std::vector<QuantType *> ptr_B_host(options.groups);
  std::vector<MmaType *> ptr_B_dq_host(options.groups);
  std::vector<ElementC *> ptr_C_host(options.groups);
  std::vector<ElementC *> ptr_D_host(options.groups);
  std::vector<cutlass::Array<ElementScale, 8> *> ptr_scale_packed_host(options.groups);
  std::vector<ElementZero *> ptr_zero_host(options.groups);
  std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
  std::vector<ElementAccumulator *> ptr_beta_host(options.groups);

  for (int32_t i = 0; i < options.groups; ++i) {
    ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
    ptr_B_host.at(i) = block_B_modified.get() + offset_B.at(i);
    ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
    ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
    ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
    ptr_scale_packed_host.at(i) = block_scale_packed.get() + offset_scale.at(i);
    ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
    alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
    beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
    ptr_alpha_host.at(i) = block_alpha.get() + i;
    ptr_beta_host.at(i) = block_beta.get() + i;
  }

  ptr_A.reset(options.groups);
  ptr_A.copy_from_host(ptr_A_host.data());

  ptr_B.reset(options.groups);
  ptr_B.copy_from_host(ptr_B_host.data());

  ptr_B_dq.reset(options.groups);
  ptr_B_dq.copy_from_host(ptr_B_dq_host.data());

  ptr_C.reset(options.groups);
  ptr_C.copy_from_host(ptr_C_host.data());

  ptr_D.reset(options.groups);
  ptr_D.copy_from_host(ptr_D_host.data());

  ptr_scale_packed.reset(options.groups);
  ptr_scale_packed.copy_from_host(ptr_scale_packed_host.data());

  ptr_zero.reset(options.groups);
  ptr_zero.copy_from_host(ptr_zero_host.data());

  stride_A.reset(options.groups);
  stride_A.copy_from_host(stride_A_host.data());

  stride_B.reset(options.groups);
  stride_B.copy_from_host(stride_B_host.data());

  stride_C.reset(options.groups);
  stride_C.copy_from_host(stride_C_host.data());

  stride_D.reset(options.groups);
  stride_D.copy_from_host(stride_D_host.data());

  stride_C_ref.reset(options.groups);
  stride_C_ref.copy_from_host(stride_C_host_ref.data());

  stride_D_ref.reset(options.groups);
  stride_D_ref.copy_from_host(stride_D_host_ref.data());

  stride_S_ref.reset(options.groups);
  stride_S_ref.copy_from_host(stride_S_host_ref.data());

  stride_S.reset(options.groups);
  stride_S.copy_from_host(stride_S_host.data());

  alpha_device.reset(options.groups);
  alpha_device.copy_from_host(ptr_alpha_host.data());
  beta_device.reset(options.groups);
  beta_device.copy_from_host(ptr_beta_host.data());

  initialize_tensor(block_A, seed + 2023);
  initialize_quant_tensor(block_B, seed + 2022);
  cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
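  // The kernel consumes the re-encoded values in block_B_modified (see the PRMT note above),
  // while block_B keeps the original int4 encoding for the host-side dequantization in verify().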
  initialize_tensor(block_C, seed + 2021);
  initialize_scale(block_scale, options);
  cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size());
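  // Scales are packed into cutlass::Array<ElementScale, 8> to match the tuple element type the
  // mainloops above declare; this packed form is what the PRMT-based conversion path consumes.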
|
||||
initialize_zero(block_zero, options);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
|
||||
if (options.shuffle) {
|
||||
std::vector<LayoutB_Reordered> layout_B_reordered_host(options.groups);
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
|
||||
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
cutlass::reorder_tensor(block_B_modified.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]);
|
||||
if (i == 0) {
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered_host[0]);
|
||||
print("\n");
|
||||
}
|
||||
}
|
||||
layout_B_reordered.reset(options.groups);
|
||||
layout_B_reordered.copy_from_host(layout_B_reordered_host.data());
|
||||
}
|
||||
|
||||
// Reverse MN -> NM for SwapAB
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto [M, N, K] = options.problem_sizes_host[i];
|
||||
options.problem_sizes_host[i] = make_tuple(N, M, K);
|
||||
}
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
using Args = typename Gemm::Arguments;
|
||||
auto&& dB = [&]() {
|
||||
if constexpr (cute::is_same_v<Gemm, GemmShuffled>) { // offline swizzling is enabled.
|
||||
return layout_B_reordered.get();
|
||||
}
|
||||
else {
|
||||
return stride_B.get();
|
||||
}
|
||||
}();
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
Args arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
bool verify(Options const& options) {
|
||||
bool passed = true;
|
||||
|
||||
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
|
||||
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
|
||||
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
ScheduleRef
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
|
||||
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
|
||||
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
|
||||
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
|
||||
|
||||
const ElementD epsilon(1e-2f);
|
||||
const ElementD non_zero_floor(1e-4f);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
if (M == 0) {
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
StrideA_verif stride_A_verif;
|
||||
StrideB_verif stride_B_verif;
|
||||
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
const int scale_k = 1;
|
||||
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
|
||||
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
|
||||
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
      bool group_passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
      passed &= group_passed;
      std::cout << "Group: " << i << " Status: " << group_passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
std::cout << "We passed all checks\n";
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {

  // CUTLASS must be compiled with the CUDA 12.3 Toolkit (or newer) to run this example.
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
    std::cerr << "This example requires CUDA 12.3 or newer.\n";
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  //
  // Evaluate CUTLASS kernels
  //

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
  std::cout << "Running in per-column scale mode." << std::endl;
  if (options.shuffle) {
    std::cout << "Offline shuffle enabled." << std::endl;
    run<GemmShuffled>(options, false);
  } else {
    std::cout << "Offline shuffle disabled." << std::endl;
    run<GemmScaleOnly>(options, false);
  }
#endif

  return 0;
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,678 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
    \brief Hopper grouped GEMM example with mixed input data types: a BF16 A operand is
      paired with a narrower quantized B operand (FP8 E5M2 here) that the mainloop
      converts, and optionally scales, on the fly. See
      examples/55_hopper_mixed_dtype_gemm for the non-grouped variant of this technique.
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <typeinfo>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "grouped_mixed_dtype_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using MmaType = cutlass::bfloat16_t;
|
||||
using QuantType = cutlass::float_e5m2_t;
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Choose a K tile that spans 128 bytes (1024 bits) of the MMA operand type, i.e. 64 elements for a 16-bit MmaType.
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType;
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes the operands, so we keep the transposes of the input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using ElementZero = cutlass::bfloat16_t;
|
||||
using ElementScale = cutlass::bfloat16_t;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementB, LayoutB_Transpose *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopConvertOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
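// GemmConvertOnly corresponds to mode 0: the mainloop only converts B's quantized values
// to the MMA type; no scales or zero-points are applied.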
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The scale information must be paired with the operand that will be scaled. In this example,
// B is scaled, so we make a tuple of B's element type and the scale element type.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
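// Note: packing ElementScale together with ElementB in the cute::tuple above is what
// switches the mainloop into its ConvertAndScale mode. As a sketch, a zero-point variant
// (not built in this example) would pass cute::tuple<ElementB, ElementScale, ElementZero>
// to select ConvertAndScaleWithZero instead.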
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
|
||||
using StrideA = typename GemmConvertOnly::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename GemmConvertOnly::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename GemmConvertOnly::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename GemmConvertOnly::GemmKernel::InternalStrideD;
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_B_dq;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_scale;
|
||||
std::vector<int64_t> offset_zero;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<StrideC_ref> stride_C_host_ref;
|
||||
std::vector<StrideD_ref> stride_D_host_ref;
|
||||
std::vector<StrideS> stride_S_host;
|
||||
std::vector<StrideS_ref> stride_S_host_ref;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<MmaType> block_A;
|
||||
cutlass::DeviceAllocation<QuantType> block_B;
|
||||
cutlass::DeviceAllocation<MmaType> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_A;
|
||||
cutlass::DeviceAllocation<const QuantType *> ptr_B;
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
|
||||
cutlass::DeviceAllocation<const ElementScale *> ptr_scale;
|
||||
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
|
||||
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
|
||||
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
|
||||
cutlass::DeviceAllocation<StrideS> stride_S;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using Options = GroupedMixedDtypeOptions<QuantType>;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(Options const& options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_B_dq = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_scale = 0;
|
||||
int64_t total_elements_zero = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
const int scale_k = 1;
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
    // QuantType may be sub-byte (e.g. int4); pointer arithmetic on QuantType* then advances
    // one storage byte per unit, so convert the running element count into bytes for B's offset.
    offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
offset_B_dq.push_back(total_elements_B_dq);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_scale.push_back(total_elements_scale);
|
||||
offset_zero.push_back(total_elements_zero);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
    int64_t elements_B = K * N;
|
||||
int64_t elements_B_dq = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_scale = scale_k * N;
|
||||
int64_t elements_zero = scale_k * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_B_dq += elements_B_dq;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_scale += elements_scale;
|
||||
total_elements_zero += elements_zero;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
|
||||
stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
|
||||
stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
|
||||
stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
|
||||
stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_B_dq.reset(total_elements_B_dq);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
block_scale.reset(total_elements_scale);
|
||||
block_zero.reset(total_elements_zero);
|
||||
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options &options) {
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<MmaType *> ptr_A_host(options.groups);
|
||||
std::vector<QuantType *> ptr_B_host(options.groups);
|
||||
std::vector<MmaType *> ptr_B_dq_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
std::vector<ElementScale *> ptr_scale_host(options.groups);
|
||||
std::vector<ElementZero *> ptr_zero_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i);
|
||||
ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_B_dq.reset(options.groups);
|
||||
ptr_B_dq.copy_from_host(ptr_B_dq_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_scale.reset(options.groups);
|
||||
ptr_scale.copy_from_host(ptr_scale_host.data());
|
||||
|
||||
ptr_zero.reset(options.groups);
|
||||
ptr_zero.copy_from_host(ptr_zero_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
stride_C_ref.reset(options.groups);
|
||||
stride_C_ref.copy_from_host(stride_C_host_ref.data());
|
||||
|
||||
stride_D_ref.reset(options.groups);
|
||||
stride_D_ref.copy_from_host(stride_D_host_ref.data());
|
||||
|
||||
stride_S_ref.reset(options.groups);
|
||||
stride_S_ref.copy_from_host(stride_S_host_ref.data());
|
||||
|
||||
stride_S.reset(options.groups);
|
||||
stride_S.copy_from_host(stride_S_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_quant_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
  problem_sizes.reset(options.groups);
  // Reverse MN -> NM for SwapAB: the kernel swaps operands A and B to compute
  // D^T = B^T * A^T, so each group's problem shape must be supplied as (N, M, K).
  for (int32_t i = 0; i < options.groups; ++i) {
    auto [M, N, K] = options.problem_sizes_host[i];
    options.problem_sizes_host[i] = make_tuple(N, M, K);
  }
  problem_sizes.copy_from_host(options.problem_sizes_host.data());
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
    // If both alpha and beta are provided via command-line arguments they are scalars,
    // i.e., the same alpha/beta applies to all groups.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
  else {
    // Otherwise use the per-group pointer arrays, i.e., alpha/beta can differ between groups.
    fusion_args.alpha = 0;
    fusion_args.beta = 0;
    fusion_args.alpha_ptr = nullptr;
    fusion_args.beta_ptr = nullptr;
    fusion_args.alpha_ptr_array = alpha_device.get();
    fusion_args.beta_ptr_array = beta_device.get();
    // One alpha and beta per group: the batch stride of 1 steps through the pointer
    // array group by group, while the cute::_0{} strides broadcast within a group.
    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
  }
|
||||
|
||||
if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
  else {
    std::cerr << "Invalid mode " << options.mode << ". This example supports mode 0 (convert only) and mode 1 (scale only)." << std::endl;
    exit(-1);
  }
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
bool passed = true;
|
||||
|
||||
  // The reference GEMM consumes the dequantized B operand, so it runs entirely in MmaType.
  // FP8 element types need an explicit FastAccum schedule (pingpong for 64-row tiles,
  // cooperative otherwise); other types can rely on the builder's automatic schedule.
  constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
  using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
  using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
ScheduleRef
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
|
||||
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
|
||||
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
|
||||
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
|
||||
|
||||
const ElementD epsilon(1e-2f);
|
||||
const ElementD non_zero_floor(1e-4f);
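  // Comparison tolerances for the device-side check below: epsilon is the relative error
  // bound and non_zero_floor guards near-zero reference values (approximate semantics of
  // cutlass::reference::device::BlockCompareRelativelyEqual).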
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
    auto problem = options.problem_sizes_host.at(i);
    // Problem shapes were reversed to (N, M, K) for SwapAB in initialize(), so mode 0
    // holds N and mode 1 holds M.
    auto N = get<0>(problem);
    auto M = get<1>(problem);
    auto K = get<2>(problem);
|
||||
if (M == 0) {
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
StrideA_verif stride_A_verif;
|
||||
StrideB_verif stride_B_verif;
|
||||
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
      const int scale_k = 1;
      auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
      auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
      cudaStream_t stream = cudaStreamDefault;
      // Dequantize this group's B operand on device so the reference GEMM below sees the
      // same BF16 values the mixed-dtype kernel reconstructs in its mainloop.
      cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
|
||||
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
      bool group_passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
      passed &= group_passed;
      std::cout << "Group: " << i << " Status: " << group_passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
std::cout << "We passed all checks\n";
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {

  // CUTLASS must be compiled with the CUDA 12.3 Toolkit (or newer) to run this example.
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
    std::cerr << "This example requires CUDA 12.3 or newer.\n";
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (props.major != 9 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
    return 0;
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  //
  // Evaluate CUTLASS kernels
  //

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
  if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
    std::cout << "Running in no scale mode." << std::endl;
    run<GemmConvertOnly>(options, false);
  }
  else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
    std::cout << "Running in group scale mode." << std::endl;
    run<GemmScaleOnly>(options, false);
  }
#endif

  return 0;
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt (new file, 112 lines)
@ -0,0 +1,112 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Note that we set --iterations=0 for most of the tests below to disable performance
# benchmarking; those commands run only the correctness check.
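# For example, the first configuration below is equivalent to invoking the built binary as:
#   ./69_hopper_mixed_dtype_grouped_gemm --iterations=0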
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=100 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=2.0 --beta=2.0 --groups=100 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=0.25 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=16 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=100 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=100 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10) # Random problem sizes
|
||||
|
||||
set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion
|
||||
|
||||
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_mixed_dtype_grouped_gemm
|
||||
69_hopper_mixed_dtype_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_int4_fp8_grouped_gemm
|
||||
69_hopper_int4_fp8_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_int4_bf16_grouped_gemm
|
||||
69_hopper_int4_bf16_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
)
|
||||
@ -0,0 +1,194 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "../55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp"
|
||||
|
||||
template<class QuantType>
|
||||
class GroupedMixedDtypeOptions : public MixedDtypeOptions {
|
||||
public:
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int,int,int>>;
|
||||
using UnderlyingProblemShape = typename ProblemShape::UnderlyingProblemShape;
|
||||
|
||||
int groups = 6;
|
||||
int c = 512;
|
||||
std::string benchmark_path;
|
||||
std::vector<UnderlyingProblemShape> problem_sizes_host;
|
||||
|
||||
  GroupedMixedDtypeOptions() : MixedDtypeOptions() {
    m = 1024;
    n = 2048;
    k = 512;
  }
|
||||
|
||||
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);
    cmd.get_cmd_line_argument("groups", groups);
    cmd.get_cmd_line_argument("c", c);
    // Read the benchmark path here so the problem list below can be built from it.
    cmd.get_cmd_line_argument("benchmark", benchmark_path);
    MixedDtypeOptions::parse(argc, args);

    problem_sizes_host = benchmark_path.empty() ? randomize_problems(cmd) : load_benchmark_problems();
  }
|
||||
|
||||
std::ostream& print_usage(std::ostream& out) const {
|
||||
out << "69_hopper_mixed_dtype_grouped_gemm\n\n"
|
||||
<< "Options:\n"
|
||||
<< " --help Display this usage statement\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems\n"
|
||||
<< " --mode=<int> The mode to run the gemm\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations\n"
|
||||
<< " --warmup=<int> Number of warmup iterations\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
double gflops(double runtime_s) const {
|
||||
uint64_t fmas = std::accumulate(problem_sizes_host.begin(), problem_sizes_host.end(), 0ULL,
|
||||
[](uint64_t sum, const UnderlyingProblemShape& problem) {
|
||||
return sum + static_cast<uint64_t>(cute::get<0>(problem)) *
|
||||
static_cast<uint64_t>(cute::get<1>(problem)) *
|
||||
static_cast<uint64_t>(cute::get<2>(problem));
|
||||
});
|
||||
return (2.0 * fmas) / (runtime_s * 1e9);
|
||||
}
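  // For example, a single group with M = N = K = 1024 contributes 2 * 1024^3 ~ 2.15e9 FLOPs,
  // so a total runtime of 1 ms would report roughly 2147 GFLOPS.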
|
||||
|
||||
private:
|
||||
static constexpr int tma_alignment_bits = 128;
|
||||
const int alignment = tma_alignment_bits / cutlass::sizeof_bits<QuantType>::value;
|
||||
|
||||
std::vector<UnderlyingProblemShape> randomize_problems(cutlass::CommandLine& cmd) {
|
||||
std::vector<UnderlyingProblemShape> problems;
|
||||
problems.reserve(groups);
|
||||
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
for (int i = 0; i < groups; ++i) {
|
||||
int m = (cmd_line_m >= 0) ? cmd_line_m : alignment * ((rand() % 64) + 1);
|
||||
int n = (cmd_line_n >= 0) ? cmd_line_n : this->n;
|
||||
int k = (cmd_line_k >= 0) ? cmd_line_k : this->k;
|
||||
|
||||
if (k % alignment != 0) {
|
||||
throw std::runtime_error("Error: k dimension must be a multiple of " + std::to_string(alignment));
|
||||
}
|
||||
problems.push_back({m, n, k});
|
||||
}
|
||||
return problems;
|
||||
}
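  // With an 8-bit QuantType, for instance, alignment = 128 / 8 = 16, so randomized M values
  // land on multiples of 16 in the range [16, 1024].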
|
||||
|
||||
std::vector<UnderlyingProblemShape> load_benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file) {
|
||||
throw std::runtime_error("Failed to open benchmark file: " + benchmark_path);
|
||||
}
|
||||
|
||||
std::vector<UnderlyingProblemShape> problems;
|
||||
int idx;
|
||||
std::string extent_str;
|
||||
|
||||
while (file >> idx >> extent_str) {
|
||||
if (idx < 0 || extent_str.empty()) break;
|
||||
|
||||
std::vector<std::string> tokens;
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
      cutlass::gemm::GemmCoord extent;
      for (int i = 0; i < std::min(3, static_cast<int>(tokens.size())); ++i) {
        int x = std::stoi(tokens[i]);
        // Round each extent up to the TMA alignment so every benchmark shape is loadable.
        extent.at(i) = (x % alignment) ? x + (alignment - (x % alignment)) : x;
      }
|
||||
|
||||
if (extent.product()) {
|
||||
problems.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problems.size());
|
||||
return problems;
|
||||
}
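  // The benchmark file is parsed as one '<index> <M>x<N>x<K>' entry per line; for example,
  // a two-group file could read:
  //   0 1024x2048x512
  //   1 4096x4096x8192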
|
||||
};
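// A minimal usage sketch, assuming the FP8 E5M2 quantized type used by the example above:
//   GroupedMixedDtypeOptions<cutlass::float_e5m2_t> options;
//   options.parse(argc, args);  // fills options.problem_sizes_host with one <M,N,K> per group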
|
||||
|
||||
template <class QuantType, class Gemm, class ElementAccumulator>
|
||||
void grouped_mixed_dtype_profiling(
|
||||
Gemm& gemm,
|
||||
const GroupedMixedDtypeOptions<QuantType>& options,
|
||||
MixedDtypeResult& result,
|
||||
const std::vector<ElementAccumulator>& alpha_host,
|
||||
const std::vector<ElementAccumulator>& beta_host) {
|
||||
|
||||
if (options.iterations <= 0) return;
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
std::vector<float> runtimes;
|
||||
runtimes.reserve(options.iterations);
|
||||
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
}
|
||||
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
|
||||
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta\n";
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n';
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << '\n'
|
||||
<< " Avg runtime : " << result.avg_runtime_ms << " ms\n"
|
||||
<< " GFLOPS : " << result.gflops << '\n';
|
||||
}
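// Note: the event-based timing above measures only the device execution of gemm.run();
// host-side argument construction and workspace allocation are excluded from the average.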
|
||||
@ -124,13 +124,14 @@ constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // A
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
|
||||
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
|
||||
using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands
|
||||
constexpr int OutputSFVectorSize = 16;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
|
||||
cutlass::epilogue::thread::SiLu,
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementAccumulator,
|
||||
ElementSF,
|
||||
ElementSFD,
|
||||
LayoutC,
|
||||
ElementC>;
|
||||
|
||||
|
||||
@ -466,7 +466,6 @@ struct ExampleRunner {
|
||||
|
||||
int max_seqlen_kv = 0;
|
||||
for (auto e : seqlen_kv) {
|
||||
// if (options.varlen) std::cout << "seqlen " << e << std::endl;
|
||||
max_seqlen_kv = std::max(e, max_seqlen_kv);
|
||||
}
|
||||
|
||||
|
||||
@ -29,11 +29,11 @@
|
||||
|
||||
set_property(
|
||||
SOURCE 77_blackwell_fmha.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
|
||||
|
||||
set_property(
|
||||
SOURCE 77_blackwell_fmha_gen.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
|
||||
|
||||
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
|
||||
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
|
||||
@ -529,7 +529,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
|
||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _);
|
||||
|
||||
// Each thread owns a single row
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
||||
@ -822,9 +821,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
|
||||
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i);
|
||||
// tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
|
||||
// tTMEM_LOADsO_i.data() = tTMEM_LOADsO_i.data().get() + sO.layout()(_0{}, i * kCorrectionTileSize, _0{});
|
||||
|
||||
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
|
||||
|
||||
@ -939,8 +936,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
cute::mul(out, scale_f32x2, in);
|
||||
tTMrO_i(j) = out.x;
|
||||
tTMrO_i(j+1) = out.y;
|
||||
//tTMrO(j) = scale * tTMrO(j);
|
||||
//tTMrO(j+1) = scale * tTMrO(j+1);
|
||||
}
|
||||
|
||||
copy_out(i);
|
||||
|
||||
@ -538,7 +538,6 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
|
||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _);
|
||||
|
||||
// Each thread owns a single row
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
||||
@ -956,8 +955,6 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
cute::mul(out, scale_f32x2, in);
|
||||
tTMrO_i(j) = out.x;
|
||||
tTMrO_i(j+1) = out.y;
|
||||
//tTMrO(j) = scale * tTMrO(j);
|
||||
//tTMrO(j+1) = scale * tTMrO(j+1);
|
||||
}
|
||||
|
||||
copy_out(i);
|
||||
|
||||
@ -188,21 +188,15 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
|
||||
// Q1
|
||||
int q0_index = get<0>(blk_coord);
|
||||
// pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
|
||||
// copy_with_limit(tiled_copy_q, tQcQ, limitQ, tQgQ, tQsQ(_, _, _, _, pipeline_q_producer_state.index());
|
||||
auto load_q = [&](int q_index, auto& state) {
|
||||
pipeline_q.producer_acquire(state);
|
||||
|
||||
// using Vec = Element;
|
||||
// auto vzero = Element(0);
|
||||
// q is always loaded masked
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
//auto src = recast<Vec>(tQgQ(_, _, _, _, q_index));
|
||||
auto src = recast<Vec>(tQgQ(_, _, _, _));
|
||||
auto dst = recast<Vec>(tQsQ(_, _, _, _, state.index()));
|
||||
// auto c = tQcQ(_, _, _, _, q_index);
|
||||
auto c = tQcQ(_, _, _, _);
|
||||
int vlen = sizeof(Vec) / sizeof(Element);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -220,7 +214,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
};
|
||||
|
||||
load_q(q0_index, pipeline_q_producer_state);
|
||||
// pipeline_q.producer_commit(pipeline_q_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{}));
|
||||
@ -287,8 +280,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index()));
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
} else {
|
||||
// using Vec = Element;
|
||||
// auto vzero = Element(0);
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
auto src = recast<Vec>(tKgK(_, _, _, _, k_index));
|
||||
@ -322,8 +313,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index()));
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
} else {
|
||||
// using Vec = Element;
|
||||
// auto vzero = Element(0);
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
auto src = recast<Vec>(tVgV(_, _, _, _, v_index));
|
||||
|
||||
@ -149,8 +149,6 @@ void __global__ fmha_fwd_gen_reference_kernel(
|
||||
|
||||
__syncthreads();
|
||||
for (int idx_d = threadIdx.x; idx_d < kDim; idx_d += blockDim.x) {
|
||||
|
||||
// printf("O[%d,%d,%d] = %f\n", idx_d, idx_h, idx_b, mS[idx_d]);
|
||||
mO(_0{}, idx_d, make_coord(idx_h, idx_b)) = static_cast<typename TensorO::value_type>(mS[idx_d]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,475 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief A Blackwell CUTLASS GEMM example for FastFP32 (using BF16 to emulate SGEMM).

    This example demonstrates how to run an emulated SGEMM with BF16x9 on an NVIDIA GPU that
    supports NVIDIA's Blackwell architecture (SM100a). Using BF16x9 leverages tensor cores,
    providing much greater throughput than SIMT instructions.

    To emulate SGEMM using BF16x9, each FP32 element of the A and B matrices is decomposed
    into three lower-precision BF16 terms:
      a = a1 + a2 + a3
      b = b1 + b2 + b3

    One FP32 MAC is then equivalent to 9 MACs using BF16:
      a * b + c = a1*b1 + a1*b2 + a1*b3 + a2*b1 + a2*b2 + a2*b3 + a3*b1 + a3*b2 + a3*b3 + c

    Example 27 demonstrates a similar technique for emulated SGEMM using TF32 with the Ampere
    architecture.

    Usage:

      $ ./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm --m=8192 --n=8192 --k=8192
*/
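
// A minimal illustrative sketch (not part of the CUTLASS example itself), assuming only
// <cuda_bf16.h>: each float is split into three bf16 terms by repeatedly rounding and
// carrying the residual, and one FP32 MAC is then reconstructed from the nine bf16
// products, mirroring the decomposition described above.
#include <cuda_bf16.h>

inline void decompose_bf16x3(float a, __nv_bfloat16 parts[3]) {
  float r = a;
  for (int i = 0; i < 3; ++i) {
    parts[i] = __float2bfloat16(r);     // keep the leading mantissa bits representable in bf16
    r -= __bfloat162float(parts[i]);    // carry the rounding error into the next term
  }
}

inline float mac_bf16x9(float a, float b, float c) {
  __nv_bfloat16 av[3], bv[3];
  decompose_bf16x3(a, av);
  decompose_bf16x3(b, bv);
  float acc = c;
  for (int i = 0; i < 3; ++i) {         // nine bf16 multiplies accumulated in FP32,
    for (int j = 0; j < 3; ++j) {       // mirroring a*b + c = sum_{i,j} ai*bj + c
      acc += __bfloat162float(av[i]) * __bfloat162float(bv[j]);
    }
  }
  return acc;
}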

#include <iostream>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"

#include "helper.h"

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////

// A matrix configuration
using ElementA = float;                       // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor;    // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = float;                       // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C/D matrix configuration
using ElementC = float;                       // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

// Kernel functional config
using ElementAccumulator = float;             // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm100;         // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag

// Kernel Perf config
using ClusterTileShape = Shape<_256,_128,_16>;  // Cluster-level tile shape
using ClusterShape = Shape<_2,_1,_1>;           // Shape of the threadblocks in a cluster
using CtaTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{}));  // Threadblock-level tile shape
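// Note: with the shapes above, shape_div yields a 128x128x16 per-CTA tile; the 256-wide
// MMA tile is executed cooperatively by the two CTAs in the cluster.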
using MmaTileShape = Shape<_256,_128,_16>;      // Mma instruction shape

// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    CtaTileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, AlignmentC,
    ElementC, LayoutC, AlignmentC,
    cutlass::epilogue::NoSmemWarpSpecialized
  >::CollectiveOp;

// Build the mainloop
// Note: Emulated BF16x9 kernels need to manually specify a mainloop schedule and cannot use KernelScheduleAuto
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedFastFP32SmemSm100;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    MainloopSchedule
  >::CollectiveOp;

// Compose into a kernel
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,  // Indicates ProblemShape
    CollectiveMainloop,
    CollectiveEpilogue,
    void>;                   // Default to ClusterLaunchControl (CLC) based tile scheduler

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    ElementAccumulator>;

using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;

//
// Data members
//

/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;

cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;

#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;

  float alpha, beta;
  int iterations;
  int m, n, k;

  Options():
    help(false),
    alpha(1.f), beta(0.f),
    iterations(10),
    m(8192), n(8192), k(8192)
  { }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
      return;
    }

    cmd.get_cmd_line_argument("m", m);
    cmd.get_cmd_line_argument("n", n);
    cmd.get_cmd_line_argument("k", k);
    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
    cmd.get_cmd_line_argument("beta", beta, 0.f);
    cmd.get_cmd_line_argument("iterations", iterations);
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "78_blackwell_emulated_bf16x9_gemm\n\n"
        << "  Blackwell emulated BF16x9 GEMM using a Warp Specialized kernel.\n\n"
        << "Options:\n\n"
        << "  --help                      If specified, displays this usage statement\n\n"
        << "  --m=<int>                   Sets the M extent of the GEMM\n"
        << "  --n=<int>                   Sets the N extent of the GEMM\n"
        << "  --k=<int>                   Sets the K extent of the GEMM\n"
        << "  --alpha=<f32>               Epilogue scalar alpha\n"
        << "  --beta=<f32>                Epilogue scalar beta\n\n"
        << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";

    out
      << "\n\nExamples:\n\n"
      << "$ " << "78_blackwell_emulated_bf16x9_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const
  {
    // Two flops per multiply-add
    uint64_t flop = uint64_t(2) * m * n * k;
    double gflop = double(flop) / double(1.0e9);
    return gflop / runtime_s;
  }
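
  // Note: this is the FP32-equivalent flop count (2*M*N*K); the hardware issues roughly
  // nine BF16 MACs per emulated FP32 MAC, so actual tensor-core work is about 9x what the
  // reported GFLOPS figure suggests.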
};

/// Result structure
struct Result
{
  double avg_runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  Result(
    double avg_runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess)
  :
    avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
  {}

};

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
  cutlass::DeviceAllocation<Element>& block,
  uint64_t seed = 2023) {

  Element scope_max, scope_min;
  int bits_input = cutlass::sizeof_bits<Element>::value;

  if (bits_input == 1) {
    scope_max = Element(2);
    scope_min = Element(0);
  } else if (bits_input <= 8) {
    scope_max = Element(2);
    scope_min = Element(-2);
  } else {
    scope_max = Element(8);
    scope_min = Element(-8);
  }
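
  // Note: these narrow ranges keep products and partial sums exactly representable in FP32,
  // which is presumably what lets verify() use a bitwise BlockCompareEqual against the
  // reference GEMM rather than a tolerance-based comparison.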

  cutlass::reference::device::BlockFillRandomUniform(
    block.get(), block.size(), seed, scope_max, scope_min, 0);

  return true;
}

/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {

  stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
  stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
  stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
  stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});

  block_A.reset(options.m * options.k);
  block_B.reset(options.k * options.n);
  block_C.reset(options.m * options.n);
  block_D.reset(options.m * options.n);
  block_ref_D.reset(options.m * options.n);

  initialize_block(block_A, seed + 2023);
  initialize_block(block_B, seed + 2022);
  initialize_block(block_C, seed + 2021);
}

/// Populates a Gemm::Arguments structure from the given command-line options
typename Gemm::Arguments args_from_options(const Options &options)
{
  typename Gemm::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemm,
    {options.m, options.n, options.k, 1},
    {block_A.get(), stride_A, block_B.get(), stride_B},
    {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
  };

  return arguments;
}

bool verify(const Options &options) {
  cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
  cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
  cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
  cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));

  //
  // Compute reference output
  //

  // Create instantiation for device reference gemm kernel
  DeviceGemmReference gemm_reference;

  // Launch device reference gemm kernel
  gemm_reference(
    {options.m, options.n, options.k},
    ElementAccumulator(options.alpha),
    ref_A,
    ref_B,
    ElementAccumulator(options.beta),
    ref_C,
    ref_D);

  // Wait for kernel to finish
  CUDA_CHECK(cudaDeviceSynchronize());

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());

  return passed;
}

/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
  initialize(options);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm;

  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
  auto arguments = args_from_options(options);

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check if the problem size is supported or not
  CUTLASS_CHECK(gemm.can_implement(arguments));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));

  // Correctness / Warmup iteration
  CUTLASS_CHECK(gemm.run());

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  Result result;
  result.passed = verify(options);

  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

  if (!result.passed) {
    exit(-1);
  }

  // Run profiling loop
  if (options.iterations > 0)
  {
    GpuTimer timer;
    timer.start();
    for (int iter = 0; iter < options.iterations; ++iter) {
      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
      CUTLASS_CHECK(gemm.run());
    }
    timer.stop();

    // Compute average runtime and GFLOPs.
    float elapsed_ms = timer.elapsed_millis();
    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

    std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
    std::cout << "  GFLOPS: " << result.gflops << std::endl;
  }

  return 0;
}

#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  // CUTLASS must be compiled with the CUDA 12.8 Toolkit or newer to run this example,
  // and the device must have compute capability 100a.
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
    // Returning zero so this test passes on older toolkits; the example is a no-op there.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (props.major != 10 || props.minor != 0) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100).\n";
    return 0;
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  //
  // Evaluate CUTLASS kernels
  //
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
36 examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt Normal file
@ -0,0 +1,36 @@

# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

if(CUTLASS_NVCC_ARCHS MATCHES "100")
cutlass_example_add_executable(
  78_blackwell_emulated_bf16x9_gemm
  78_blackwell_emulated_bf16x9_gemm.cu
)
endif()
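
# Note: this example is only added when SM100 is among the configured target
# architectures, e.g. configure CUTLASS with -DCUTLASS_NVCC_ARCHS="100a".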
@ -146,14 +146,16 @@ foreach(EXAMPLE
  64_ada_fp8_gemm_grouped
  65_distributed_gemm
  67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
  69_hopper_mixed_dtype_grouped_gemm
  70_blackwell_gemm
  71_blackwell_gemm_with_collective_builder
  72_blackwell_narrow_precision_gemm
  73_blackwell_gemm_preferred_cluster
  74_blackwell_gemm_streamk
  75_blackwell_grouped_gemm
  76_blackwell_conv
  77_blackwell_fmha
  78_blackwell_emulated_bf16x9_gemm
  )

  add_subdirectory(${EXAMPLE})

@ -246,8 +246,6 @@

  Hopper GEMM kernel with Top-K and softmax epilogue fusion.

[//]: #

* [70_blackwell_gemm](70_blackwell_gemm)

  Simple dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
@ -280,8 +278,6 @@

  Blackwell SM100 FMHA kernel

[//]: #

# CuTe - Programming Examples

Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).