v3.8.0 update (#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Yujia Zhai
2025-02-06 18:33:40 -08:00
committed by GitHub
parent affd1b693d
commit 833f6990e0
168 changed files with 24945 additions and 3436 deletions

View File

@ -222,8 +222,8 @@ cutlass_add_cutlass_library(
# files split for parallel compilation
src/reference/gemm_int4.cu
src/reference/block_scaled_gemm_fp4a_vs16.cu
src/reference/block_scaled_gemm_fp4a_vs32.cu
src/reference/block_scaled_gemm_mixed8bitsa.cu
src/reference/gemm_f4_f4_f32.cu
src/reference/gemm_f4_f6_f32.cu

View File

@ -43,7 +43,8 @@
computational overhead
*/
#pragma once
#ifndef CUTLASS_LIBRARY_LIBRARY_H
#define CUTLASS_LIBRARY_LIBRARY_H
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -103,7 +104,7 @@ public:
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const = 0;
// Originally designed for metadata, but should be useful for FP8/6/4 too.
virtual Status initialize_with_profiler_workspace(
void const *configuration,
void *host_workspace,
@ -282,6 +283,11 @@ struct GemmUniversalConfiguration {
int device_count{1};
};
enum class Sm90MixedInputWiderOperand {
A = 0,
B = 1
};
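// The "wider" operand is the one with more bits per element (e.g. an fp8
// operand paired with an int4 operand is the wider one); see the sizeof_bits
// comparison in gemm_operation_3x.hpp.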
struct GemmUniversalArguments {
// NOTE: these are replicated for 3.0 interfaces
gemm::GemmCoord problem_size{};
@ -317,6 +323,18 @@ struct GemmUniversalArguments {
int swizzle_size{1};
int split_k_slices{1};
// For mixed input dtype kernels
bool is_mixed_dtype{false};
Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B};
bool generate_scale_and_zero{false};
bool generate_dequantized_AB{false};
bool *dequantized_AB_ready{nullptr}; // Carries readiness info back to gemm_operation_profiler.cu
void *Scale{nullptr}; // Scale tensor
void *Zero{nullptr}; // Zero tensor
void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification
void *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle
void *packed_Scale{nullptr}; // Packed scale for int4 * fp8
int device_index{0};
bool use_pdl{false};
@ -472,12 +490,16 @@ struct GemmPlanarComplexArrayArguments {
struct GemmGroupedConfiguration {
int problem_count{0};
int threadblock_count{0};
// GemmGroupedConfiguration is passed to initialize(), which
// is responsible for allocating the device-side stride storage.
int64_t* lda;
int64_t* ldb;
int64_t* ldc;
};
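// Illustrative sketch (not part of this header) of preparing a configuration;
// num_groups, tb_count, and the stride vectors are hypothetical caller state:
//
//   std::vector<int64_t> lda(num_groups), ldb(num_groups), ldc(num_groups);
//   // ... fill one leading dimension per group ...
//   GemmGroupedConfiguration config{num_groups, tb_count,
//                                   lda.data(), ldb.data(), ldc.data()};
//   operation->initialize(&config, host_workspace, device_workspace, stream);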
struct GemmGroupedArguments {
int problem_count{};
gemm::GemmCoord* problem_sizes{nullptr};
void * ptr_A{nullptr};
void * ptr_B{nullptr};
@ -493,6 +515,18 @@ struct GemmGroupedArguments {
void const *beta{nullptr};
ScalarPointerMode pointer_mode{};
bool use_pdl{false};
gemm::GemmCoord cluster_shape{};
gemm::GemmCoord cluster_shape_fallback{};
// These should really be in the configuration, but we keep them here for consistency with GEMM
int sm_count{0};
// The user is responsible for allocating storage for problem sizes.
// Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we
// unfortunately need to have both options in this struct, and the
// underlying operation uses the one it needs.
cute::Shape<int, int, int>* problem_sizes_3x;
cute::Shape<int, int, int>* problem_sizes_3x_host;
};
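// Illustrative sketch (not part of this header): the 3.x path consumes
// cute::Shape<int, int, int>, so a host-side view can be derived from the 2.x
// GemmCoord list; `args` and `problems` are hypothetical caller state, and
// the device-side copies are made separately:
//
//   std::vector<cute::Shape<int, int, int>> sizes_3x_host;
//   for (auto const &p : problems) {
//     sizes_3x_host.push_back(cute::make_shape(p.m(), p.n(), p.k()));
//   }
//   args.problem_sizes_3x_host = sizes_3x_host.data();
//   args.problem_count = static_cast<int>(problems.size());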
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -880,3 +914,5 @@ struct ReductionArguments {
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif

View File

@ -142,7 +142,7 @@ enum class Provider {
/// Enumeration indicating the kind of operation
enum class OperationKind {
kGemm,
kBlockScaledGemm,
kRankK,
kRank2K,
kTrmm,
@ -152,6 +152,7 @@ enum class OperationKind {
kEqGemm,
kSparseGemm,
kReduction,
kGroupedGemm,
kInvalid
};
@ -270,7 +271,6 @@ enum class RuntimeDatatype {
kStatic,
kE4M3,
kE5M2,
kE3M2,
kE2M3,
kE2M1,

View File

@ -34,7 +34,8 @@
\brief Utilities accompanying the CUTLASS library for interacting with Library types.
*/
#pragma once
#ifndef CUTLASS_LIBRARY_UTIL_H
#define CUTLASS_LIBRARY_UTIL_H
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
@ -213,6 +214,63 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type);
#define CUDA_CHECK(call) \
do { \
cudaError_t err = (call); \
if (err != cudaSuccess) { \
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \
<< __FILE__ << ":" << __LINE__ << std::endl; \
return Status::kInvalid; \
} \
} while (0)
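// Illustrative usage sketch: CUDA_CHECK early-returns Status::kInvalid on
// failure, so the enclosing function must itself return Status
// (dst/src/bytes are hypothetical):
//
//   Status copy_to_device(void *dst, void const *src, size_t bytes) {
//     CUDA_CHECK(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice));
//     return Status::kSuccess;
//   }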
// RAII CUDA buffer container
class CudaBuffer {
public:
CudaBuffer() : size_(0), d_ptr_(nullptr) {}
explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) {
cudaError_t err = cudaMalloc(&d_ptr_, size_);
if (err != cudaSuccess) {
throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err)));
}
}
~CudaBuffer() {
if (d_ptr_) {
cudaFree(d_ptr_);
}
}
CudaBuffer(CudaBuffer const&) = delete;
CudaBuffer& operator=(CudaBuffer const&) = delete;
CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) {
other.d_ptr_ = nullptr;
other.size_ = 0;
}
CudaBuffer& operator=(CudaBuffer&& other) noexcept {
if (this != &other) {
if (d_ptr_) {
cudaFree(d_ptr_);
}
d_ptr_ = other.d_ptr_;
size_ = other.size_;
other.d_ptr_ = nullptr;
other.size_ = 0;
}
return *this;
}
void* data() const noexcept { return d_ptr_; }
size_t size() const noexcept { return size_; }
private:
size_t size_;
void* d_ptr_;
};
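// Illustrative usage sketch: move assignment releases any prior allocation,
// so a default-constructed member can be (re)sized later, as
// GroupedGemmUniversal3xOperation::initialize() does for its per-group stride
// arrays (inside a Status-returning function; `n` and `host_strides` are
// hypothetical):
//
//   CudaBuffer strides;                          // empty, no allocation
//   strides = CudaBuffer(sizeof(int64_t) * n);   // allocates or throws
//   CUDA_CHECK(cudaMemcpy(strides.data(), host_strides.data(),
//                         sizeof(int64_t) * n, cudaMemcpyHostToDevice));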
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
@ -220,3 +278,4 @@ NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type);
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif

View File

@ -73,7 +73,7 @@ public:
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
using Sm100BlkScaledConfig = typename CollectiveMainloop::Sm100BlkScaledConfig;
static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v<typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;
static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize;
using ElementSFD = cute::conditional_t<epilogue_scalefactor_generation, typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;

View File

@ -1201,25 +1201,30 @@ public:
GemmOperationBase<Operator_>(name) {
this->description_.gemm_kind = GemmKind::kGrouped;
this->description_.kind = OperationKind::kGroupedGemm;
this->threadblock_count = Operator::sufficient();
}
private:
int threadblock_count;
protected:
/// Constructs the arguments structure given the configuration and arguments
Status construct_arguments_(
OperatorArguments &op_args,
GemmGroupedConfiguration const *config) const {
op_args.problem_count = config->problem_count;
op_args.threadblock_count = threadblock_count;
return Status::kSuccess;
}
/// Constructs the arguments structure given the configuration and arguments
Status update_arguments_(
OperatorArguments &op_args,
GemmGroupedArguments const *arguments) const {
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
@ -1243,6 +1248,8 @@ protected:
return Status::kErrorInvalidProblem;
}
op_args.threadblock_count = threadblock_count;
op_args.problem_count = arguments->problem_count;
op_args.problem_sizes = arguments->problem_sizes;
op_args.ptr_A = static_cast<ElementA **>(arguments->ptr_A);

View File

@ -36,9 +36,17 @@
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/array.h"
#include "cutlass/array_subbyte.h"
#include "cutlass/library/library.h"
#include "library_internal.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cute/tensor.hpp"
#include <unordered_map>
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -65,7 +73,7 @@ public:
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
protected:
GemmDescription description_;
public:
@ -178,7 +186,23 @@ public:
/// Constructor
GemmUniversal3xOperation(char const *name = "unknown_gemm"):
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
dim3 cluster_dims(
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
cluster_dims,
threads_per_block,
kernel_ptr);
}
}
private:
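// Cached result of the occupancy query performed in the constructor above,
// so update_arguments_() does not have to re-query the device on every call.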
int max_active_clusters{};
protected:
@ -227,10 +251,119 @@ protected:
}
};
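// Overload-based dispatch: the specialized overload below matches only when
// the policy is MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; the
// unconstrained overload is the fallback and returns false.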
template<template<int, class, class> class Policy, int Stages, class ClusterShape, class KernelSchedule>
static constexpr bool is_mixed_dtype_mainloop_(Policy<Stages, ClusterShape, KernelSchedule> policy) {
return (cute::is_same_v<Policy<Stages, ClusterShape, KernelSchedule>,
cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<Stages, ClusterShape, KernelSchedule>>);
}
template <class DispatchPolicy>
static constexpr bool is_mixed_dtype_mainloop_(DispatchPolicy) {
return false;
}
template <
typename ElementWide,
typename ElementNarrow,
typename ElementScaleMainloop,
class ActualStrideAB,
Sm90MixedInputWiderOperand wider_operand,
bool is_n4w8,
typename ElementScale,
typename ElementZero,
class Layout_SZ>
static void dequantize_encode_(
OperatorArguments &operator_args,
GemmUniversalArguments const *arguments,
cudaStream_t stream,
const int &problem_mn,
const int &problem_k,
const int &options_l,
const int &options_g,
ElementScale *ptr_S,
ElementZero *ptr_Z,
const size_t &SZ_size,
Layout_SZ layout_SZ
) {
auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l);
auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB);
auto layout_AB = cute::make_layout(shape_AB, stride_AB);
auto *ptr_dequantized_AB = static_cast<ElementWide *>(arguments->dequantized_AB);
const ElementNarrow *ptr_AB = nullptr;
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
ptr_AB = static_cast<const ElementNarrow *>(arguments->B);
}
else {
ptr_AB = static_cast<const ElementNarrow *>(arguments->A);
}
dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, ptr_S, ptr_Z, layout_SZ, options_g, stream);
if constexpr(is_n4w8) {
size_t AB_size = cute::size(layout_AB);
cutlass::int4b_t *encoded_AB = static_cast<cutlass::int4b_t *>(arguments->encoded_AB);
unified_encode_int4b(ptr_AB, encoded_AB, AB_size);
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
operator_args.mainloop.ptr_B = static_cast<ElementNarrow const *>(encoded_AB);
}
else {
operator_args.mainloop.ptr_A = static_cast<ElementNarrow const *>(encoded_AB);
}
ElementScaleMainloop *ptr_packed_Scale = static_cast<ElementScaleMainloop *>(arguments->packed_Scale);
pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size);
}
}
template <
typename ElementAB,
class ActualStrideAB,
class LayoutAB_Reordered,
class LayoutAtomQuant,
Sm90MixedInputWiderOperand wider_operand>
static void handle_shuffle_tensor_(
OperatorArguments &operator_args,
GemmUniversalArguments const *arguments,
const int &problem_mn,
const int &problem_k,
const int &options_l) {
auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l);
auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB);
auto layout_AB = cute::make_layout(shape_AB, stride_AB);
LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB);
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
operator_args.mainloop.dB = layout_AB_reordered;
}
else {
operator_args.mainloop.dA = layout_AB_reordered;
}
if (arguments->generate_dequantized_AB) {
size_t AB_size = cute::size(layout_AB);
ElementAB *AB_reordered = cutlass::device_memory::allocate<ElementAB>(AB_size);
const ElementAB *AB_src = nullptr;
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
AB_src = static_cast<const ElementAB *>(operator_args.mainloop.ptr_B);
}
else {
AB_src = static_cast<const ElementAB *>(operator_args.mainloop.ptr_A);
}
reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered);
ElementAB *AB_dst = static_cast<ElementAB *>(arguments->encoded_AB);
cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, AB_size);
cutlass::device_memory::free(AB_reordered);
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
operator_args.mainloop.ptr_B = AB_dst;
}
else {
operator_args.mainloop.ptr_A = AB_dst;
}
}
}
/// Constructs the arguments structure given the configuration and arguments
Status update_arguments_(
OperatorArguments& operator_args,
GemmUniversalArguments const* arguments,
cudaStream_t stream = nullptr) const {
Status status = Status::kSuccess;
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
@ -286,24 +419,173 @@ protected:
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
// Stride{A,B} is a Layout if and only if:
// (1) This is a mixed dtype kernel, and
// (2) This mixed dtype kernel is using shuffling, and
// (3) sizeof(narrow_type) == 4 or 8 bits, and
// (4) sizeof(wide_type) == 16 bits.
// If A/B has the narrow data type, Stride{A/B} will be a Layout
constexpr bool is_StrideA_Layout = cute::is_layout<typename CollectiveMainloop::StrideA>::value;
constexpr bool is_StrideB_Layout = cute::is_layout<typename CollectiveMainloop::StrideB>::value;
static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout");
if constexpr(!is_StrideA_Layout) {
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
arguments->lda, arguments->batch_stride_A);
}
if constexpr(!is_StrideB_Layout) {
operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
arguments->ldb, arguments->batch_stride_B);
}
operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
arguments->ldc, arguments->batch_stride_C);
operator_args.epilogue.dD = operator_args.epilogue.dC;
using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy;
if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{})) {
int problem_m = arguments->problem_size.m();
int problem_n = arguments->problem_size.n();
int problem_k = arguments->problem_size.k();
int options_l = arguments->batch_count;
constexpr Sm90MixedInputWiderOperand wider_operand =
(cutlass::sizeof_bits<ElementA>::value > cutlass::sizeof_bits<ElementB>::value) ?
Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B;
using ElementWide = std::conditional_t<wider_operand == Sm90MixedInputWiderOperand::A, ElementA, ElementB>;
using ElementNarrow = std::conditional_t<wider_operand == Sm90MixedInputWiderOperand::A, ElementB, ElementA>;
constexpr bool has_scale = !std::is_same_v<typename CollectiveMainloop::ElementScale, void>;
constexpr bool has_zero = !std::is_same_v<typename CollectiveMainloop::ElementZero, void>;
if constexpr(has_scale) {
int options_g = problem_k;
int scale_k = (problem_k + options_g - 1) / options_g;
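// Since options_g == problem_k above, this rounds up to scale_k == 1, i.e. a
// single scale group spanning the full K extent.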
constexpr bool is_A4B8 = (
cutlass::is_same_v<ElementA, cutlass::int4b_t> &&
(cutlass::is_same_v<ElementB, cutlass::float_e4m3_t> ||
cutlass::is_same_v<ElementB, cutlass::float_e5m2_t>));
constexpr bool is_A8B4 = (
cutlass::is_same_v<ElementB, cutlass::int4b_t> &&
(cutlass::is_same_v<ElementA, cutlass::float_e4m3_t> ||
cutlass::is_same_v<ElementA, cutlass::float_e5m2_t>));
constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4;
// In int4 * fp8, ElementScale is a cutlass::Array; we need to extract its underlying element
using ElementScaleMainloop = typename CollectiveMainloop::ElementScale;
using ElementScale = typename UnderlyingElement<typename CollectiveMainloop::ElementScale>::type;
using StrideS = typename CollectiveMainloop::StrideScale;
// In ScaleOnly mode, the same amount of memory is allocated for arguments->Z as for arguments->S
using ElementZero = std::conditional_t<
has_zero,
typename CollectiveMainloop::ElementZero,
ElementScale
>;
const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m;
const size_t SZ_size = static_cast<size_t>(SZ_1st_dim * scale_k * options_l);
auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l);
ElementScale *ptr_S = static_cast<ElementScale *>(arguments->Scale);
ElementZero *ptr_Z = static_cast<ElementZero *>(arguments->Zero);
// 1. If the arguments are initialized in the profiler, S and Z need to be allocated and filled
if (arguments->generate_scale_and_zero) {
// Need to fix max_dequant_val and min_dequant_val?
const float elt_max_f = float(cutlass::platform::numeric_limits<ElementScale>::max());
const float max_dequant_val = elt_max_f * 0.25f;
const float min_dequant_val = 0.5f;
const float scale_max = max_dequant_val / elt_max_f;
const float scale_min = min_dequant_val / elt_max_f;
uint64_t seed = 2023;
cutlass::reference::device::BlockFillRandomUniform(
ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min));
// In ScaleOnly mode, set Z to zero when generating the dequantized A or B
const float zero_max = has_zero ? 2.0f : 0.0f;
const float zero_min = has_zero ? -2.0f : 0.0f;
cutlass::reference::device::BlockFillRandomUniform(
ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min));
} // End of "if (arguments->generate_scale_and_zero)"
// 2. Generate the dequantized A or B for verification
if (arguments->generate_dequantized_AB) {
StrideS stride_SZ = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ);
auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ);
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
if constexpr(is_StrideB_Layout) {
// The generator only generates row-major A and col-major B at the moment
// Need a way to read out the actual layout of B later
using ActualLayoutB = cutlass::layout::ColumnMajor;
using ActualStrideB = cutlass::detail::TagToStrideB_t<ActualLayoutB>;
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideB, wider_operand, is_A8B4>(
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
}
else {
using ActualStrideB = typename CollectiveMainloop::StrideB;
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideB, wider_operand, is_A8B4>(
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
}
}
else {
if constexpr(is_StrideA_Layout) {
// The generator only generates row-major A and col-major B at the moment
// Need a way to read out the actual layout of A later
using ActualLayoutA = cutlass::layout::RowMajor;
using ActualStrideA = cutlass::detail::TagToStrideA_t<ActualLayoutA>;
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideA, wider_operand, is_A4B8>(
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
}
else {
using ActualStrideA = typename CollectiveMainloop::StrideA;
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideA, wider_operand, is_A4B8>(
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
}
} // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)"
arguments->dequantized_AB_ready[0] = true;
} // End of "if (arguments->generate_dequantized_AB)"
// 3. Populate the mainloop arguments
if constexpr(is_int4_x_fp8) {
operator_args.mainloop.ptr_S = static_cast<ElementScaleMainloop const*>(arguments->packed_Scale);
}
else {
operator_args.mainloop.ptr_S = static_cast<ElementScale const*>(arguments->Scale);
}
operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ);
operator_args.mainloop.group_size = options_g;
if constexpr(has_zero) {
operator_args.mainloop.ptr_Z = static_cast<ElementZero const*>(arguments->Zero);
}
} // End of "if constexpr(has_scale)"
// Handle the shuffling
using ValueShuffle = std::conditional_t<
cutlass::sizeof_bits<ElementNarrow>::value == 4,
cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>,
cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>
>;
constexpr int NumShuffleAtoms = 1;
using MmaAtomShape = cute::Layout<cute::Shape<cute::_1,cute::Int<NumShuffleAtoms>>>;
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<ElementWide, MmaAtomShape, ValueShuffle>());
// The generator only generates row-major A and col-major B at the moment
// Need a way to read out the actual layout and stride of A/B later
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) {
using ActualLayoutB = cutlass::layout::ColumnMajor;
using ActualStrideB = cutlass::detail::TagToStrideB_t<ActualLayoutB>;
using LayoutB_Reordered = typename CollectiveMainloop::StrideB;
handle_shuffle_tensor_<ElementB, ActualStrideB, LayoutB_Reordered, LayoutAtomQuant, wider_operand>(
operator_args, arguments, problem_n, problem_k, options_l);
}
if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) {
using ActualLayoutA = cutlass::layout::RowMajor;
using ActualStrideA = cutlass::detail::TagToStrideA_t<ActualLayoutA>;
using LayoutA_Reordered = typename CollectiveMainloop::StrideA;
handle_shuffle_tensor_<ElementA, ActualStrideA, LayoutA_Reordered, LayoutAtomQuant, wider_operand>(
operator_args, arguments, problem_m, problem_k, options_l);
}
} // End of "if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{}))"
/* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */
operator_args.hw_info.sm_count = arguments->sm_count;
if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
operator_args.hw_info.max_active_clusters = max_active_clusters;
}
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
operator_args.scheduler.max_swizzle_size = arguments->swizzle_size;
@ -356,7 +638,10 @@ public:
return status;
}
Status can_impl = Operator::can_implement(args);
return can_impl;
}
/// Gets the host-side workspace
@ -397,7 +682,7 @@ public:
cudaStream_t stream = nullptr) const override {
OperatorArguments args;
Status status = update_arguments_(args, static_cast<GemmUniversalArguments const *>(arguments_ptr), stream);
if (status != Status::kSuccess) {
return status;
}

View File

@ -0,0 +1,330 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines all grouped GEMM operations in the CUTLASS Library.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/library/library.h"
#include "cutlass/library/util.h"
#include "gemm_operation_3x.hpp"
#include "library_internal.h"
#include <unordered_map>
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::library {
/// **** CAUTION ****
/// Unlike other operations, initialize() must be called when
/// certain arguments change. See initialize() for details.
template <typename Operator_>
class GroupedGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
public:
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementD = typename Operator::ElementD;
using LayoutD = typename Operator::LayoutD;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using CollectiveMainloop = typename Operator::CollectiveMainloop;
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
private:
mutable CudaBuffer strideA_device;
mutable CudaBuffer strideB_device;
mutable CudaBuffer strideC_device;
mutable CudaBuffer strideD_device;
mutable std::vector<typename Operator::GemmKernel::InternalStrideA> strideA_host;
mutable std::vector<typename Operator::GemmKernel::InternalStrideB> strideB_host;
mutable std::vector<typename Operator::GemmKernel::InternalStrideC> strideC_host;
mutable std::vector<typename Operator::GemmKernel::InternalStrideD> strideD_host;
public:
GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm")
: GemmOperation3xBase<Operator_>(name, GemmKind::kGrouped) {
this->description_.kind = OperationKind::kGroupedGemm;
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster_dims(
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
cluster_dims,
threads_per_block,
kernel_ptr);
}
}
~GroupedGemmUniversal3xOperation() override = default;
private:
int max_active_clusters{};
protected:
template <class FusionArgs, class = void> struct UpdateFusionArgs {
static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) {
// If a custom EVT is instantiated then it is the user's responsibility
// to ensure alpha and beta are updated appropriately
return Status::kSuccess;
}
};
template <class FusionArgs>
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) {
if (arguments.pointer_mode == ScalarPointerMode::kHost) {
fusion_args.alpha = *static_cast<ElementCompute const*>(arguments.alpha);
fusion_args.beta = *static_cast<ElementCompute const*>(arguments.beta);
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
return Status::kSuccess;
}
else if (arguments.pointer_mode == ScalarPointerMode::kDevice) {
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(arguments.alpha);
fusion_args.beta_ptr = static_cast<ElementCompute const*>(arguments.beta);
return Status::kSuccess;
}
else {
return Status::kErrorInvalidProblem;
}
}
};
/// Constructs the arguments structure given the configuration and arguments
Status
update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const {
Status status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
operator_args.epilogue.thread,
*arguments);
if (status != Status::kSuccess) {
return status;
}
operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped;
operator_args.problem_shape = {
arguments->problem_count,
arguments->problem_sizes_3x,
arguments->pointer_mode == ScalarPointerMode::kHost ? arguments->problem_sizes_3x_host
: nullptr};
operator_args.mainloop.ptr_A =
static_cast<const typename Operator::ElementA**>(arguments->ptr_A);
operator_args.mainloop.ptr_B =
static_cast<const typename Operator::ElementB**>(arguments->ptr_B);
operator_args.epilogue.ptr_C =
static_cast<const typename Operator::ElementC**>(arguments->ptr_C);
operator_args.epilogue.ptr_D = static_cast<typename Operator::ElementD**>(arguments->ptr_D);
operator_args.mainloop.dA =
static_cast<typename Operator::GemmKernel::InternalStrideA*>(strideA_device.data());
operator_args.mainloop.dB =
static_cast<typename Operator::GemmKernel::InternalStrideB*>(strideB_device.data());
operator_args.epilogue.dC =
static_cast<typename Operator::GemmKernel::InternalStrideC*>(strideC_device.data());
operator_args.epilogue.dD =
static_cast<typename Operator::GemmKernel::InternalStrideD*>(strideD_device.data());
operator_args.hw_info.sm_count = arguments->sm_count;
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
operator_args.hw_info.max_active_clusters = max_active_clusters;
}
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
operator_args.hw_info.cluster_shape = dim3(
arguments->cluster_shape.m(),
arguments->cluster_shape.n(),
arguments->cluster_shape.k());
operator_args.hw_info.cluster_shape_fallback = dim3(
arguments->cluster_shape_fallback.m(),
arguments->cluster_shape_fallback.n(),
arguments->cluster_shape_fallback.k());
}
return status;
}
public:
/// Returns success if the operation can proceed
Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr)
const override {
GemmGroupedArguments const* arguments = static_cast<GemmGroupedArguments const*>(arguments_ptr);
OperatorArguments args;
auto status = update_arguments_(args, arguments);
if (status != Status::kSuccess) {
return status;
}
status = Operator::can_implement(args);
return status;
}
/// Gets the host-side workspace
uint64_t get_host_workspace_size(void const* configuration) const override {
return sizeof(Operator);
}
/// Gets the device-side workspace
uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr)
const override {
OperatorArguments args;
auto status = update_arguments_(args, static_cast<GemmGroupedArguments const*>(arguments_ptr));
if (status != Status::kSuccess) {
return 0;
}
uint64_t size = Operator::get_workspace_size(args);
return size;
}
/// Initializes the workspace
/// **** CAUTION ****
/// Must be called when lda, ldb, ldc, or ldd change.
/// The CUTLASS library stores the operations in a type-
/// erased manifest. Therefore, only this class knows
/// the type of strideA, strideB, strideC, and strideD.
/// Since grouped GEMM needs to allocate storage for
/// the strides on device, the concrete type of the stride
/// must be known in order to copy in the correct memory
/// layout on device.
Status initialize(
void const* configuration_ptr,
void* host_workspace,
void* device_workspace,
cudaStream_t stream = nullptr) const override {
auto const& config = *static_cast<GemmGroupedConfiguration const*>(configuration_ptr);
auto num_groups = config.problem_count;
strideA_device =
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups);
strideB_device =
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups);
strideC_device =
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups);
strideD_device =
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups);
strideA_host.resize(num_groups);
strideB_host.resize(num_groups);
strideC_host.resize(num_groups);
strideD_host.resize(num_groups);
for (int group_idx = 0; group_idx < num_groups; group_idx++) {
strideA_host[group_idx] =
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideA>(
config.lda[group_idx]);
strideB_host[group_idx] =
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideB>(
config.ldb[group_idx]);
strideC_host[group_idx] =
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideC>(
config.ldc[group_idx]);
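// GemmGroupedConfiguration carries no ldd, so InternalStrideD reuses the ldc
// array (the profiler likewise points arguments.ldd at its ldc allocation).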
strideD_host[group_idx] =
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideD>(
config.ldc[group_idx]);
}
CUDA_CHECK(cudaMemcpy(
strideA_device.data(),
strideA_host.data(),
sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups,
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(
strideB_device.data(),
strideB_host.data(),
sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups,
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(
strideC_device.data(),
strideC_host.data(),
sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups,
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(
strideD_device.data(),
strideD_host.data(),
sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups,
cudaMemcpyHostToDevice));
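// Construct the Operator in the caller-provided host workspace via placement
// new; run() later casts host_workspace back to Operator*.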
new (host_workspace) Operator;
return Status::kSuccess;
}
/// **** CAUTION ****
/// initialize() must be called if lda, ldb, ldc, or ldd change.
Status run(
void const* arguments_ptr,
void* host_workspace,
void* device_workspace = nullptr,
cudaStream_t stream = nullptr) const override {
OperatorArguments operator_args;
auto const& args = *static_cast<GemmGroupedArguments const*>(arguments_ptr);
Status status = update_arguments_(operator_args, &args);
if (status != Status::kSuccess) {
return status;
}
Operator* op = static_cast<Operator*>(host_workspace);
// We need to call initialize() since we have to rebuild TMA desc for every new set of args
status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl);
return status;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::library
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -64,7 +64,6 @@ void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
@ -114,7 +113,6 @@ void initialize_reference_operations(Manifest &manifest) {
initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest);
initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest);
initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -37,6 +37,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/library/library.h"
#include "cutlass/library/util.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
#include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride
@ -45,14 +46,6 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::library {
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -330,18 +330,18 @@ static struct {
char const *text;
char const *pretty;
OperationKind enumerant;
} OperationKind_enumerants[] = {
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
{"gemm", "Gemm", OperationKind::kGemm},
{"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm},
{"rank_k", "RankK", OperationKind::kRankK},
{"rank_2k", "Rank2K", OperationKind::kRank2K},
{"trmm", "Trmm", OperationKind::kTrmm},
{"symm", "Symm", OperationKind::kSymm},
{"conv2d", "Conv2d", OperationKind::kConv2d},
{"conv3d", "Conv3d", OperationKind::kConv3d},
{"conv2d", "Conv2d", OperationKind::kConv2d},
{"conv3d", "Conv3d", OperationKind::kConv3d},
{"spgemm", "SparseGemm", OperationKind::kSparseGemm},
{"grouped_gemm", "GroupedGemm", OperationKind::kGroupedGemm},
};
/// Converts a Status enumerant to a string
@ -504,7 +504,6 @@ NumericTypeID_enumerants[] = {
{"fe2m1", "FE2M1", NumericTypeID::kFE2M1},
{"fue8m0", "FUE8M0", NumericTypeID::kFUE8M0},
{"fue4m3", "FUE4M3", NumericTypeID::kFUE4M3},
{"f16", "F16", NumericTypeID::kF16},
{"bf16", "BF16", NumericTypeID::kBF16},
{"f32", "F32", NumericTypeID::kF32},
@ -577,7 +576,6 @@ int sizeof_bits(NumericTypeID type) {
case NumericTypeID::kFE2M1: return 4;
case NumericTypeID::kFUE8M0: return 8;
case NumericTypeID::kFUE4M3: return 8;
case NumericTypeID::kF16: return 16;
case NumericTypeID::kBF16: return 16;
case NumericTypeID::kTF32: return 32;
@ -666,7 +664,6 @@ bool is_signed_type(NumericTypeID type) {
case NumericTypeID::kFE2M1: return true;
case NumericTypeID::kFUE8M0: return false;
case NumericTypeID::kFUE4M3: return false;
case NumericTypeID::kF16: return true;
case NumericTypeID::kBF16: return true;
case NumericTypeID::kTF32: return true;
@ -707,7 +704,6 @@ bool is_float_type(NumericTypeID type) {
case NumericTypeID::kFE2M1: return true;
case NumericTypeID::kFUE8M0: return true;
case NumericTypeID::kFUE4M3: return true;
case NumericTypeID::kF16: return true;
case NumericTypeID::kBF16: return true;
case NumericTypeID::kTF32: return true;
@ -1256,7 +1252,6 @@ bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string c
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(tmp);
}
break;
case NumericTypeID::kFE2M3:
{
float tmp;
@ -1292,7 +1287,6 @@ bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string c
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(tmp);
}
break;
case NumericTypeID::kF16:
{
float tmp;
@ -1473,7 +1467,6 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
ss << tmp;
}
break;
case NumericTypeID::kF16:
{
float tmp = *reinterpret_cast<half_t *>(bytes.data());
@ -1652,7 +1645,6 @@ bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t sr
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
}
break;
case NumericTypeID::kF16:
{
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
@ -1789,7 +1781,6 @@ bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
}
break;
case NumericTypeID::kF16:
{
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
@ -1927,7 +1918,6 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
}
break;
case NumericTypeID::kF16:
{
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));

View File

@ -46,7 +46,8 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
src/problem_space.cpp
src/operation_profiler.cu
src/gemm_operation_profiler.cu
src/grouped_gemm_operation_profiler.cu
src/block_scaled_gemm_operation_profiler.cu
src/rank_k_operation_profiler.cu
src/rank_2k_operation_profiler.cu
src/trmm_operation_profiler.cu
@ -112,6 +113,7 @@ set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_RANK_K --operation=RankK --pro
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_RANK_2K --operation=Rank2K --providers=cutlass --verification-providers=cublas --junit-output=test_cutlass_profiler_rank_2k --print-kernel-before-running=true)
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_TRMM --operation=Trmm --providers=cutlass --verification-providers=device,host --junit-output=test_cutlass_profiler_trmm --print-kernel-before-running=true)
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_SYMM --operation=Symm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_symm --print-kernel-before-running=true)
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GROUPED_GEMM --operation=GroupedGemm --providers=cutlass --verification-providers=device --junit-output=test_cutlass_profiler_grouped_gemm --print-kernel-before-running=true)
cutlass_add_executable_tests(
test_profiler cutlass_profiler
@ -125,6 +127,7 @@ cutlass_add_executable_tests(
RANK_2K
TRMM
SYMM
GROUPED_GEMM
TEST_COMMAND_OPTIONS_PREFIX
CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_
DISABLE_EXECUTABLE_INSTALL_RULE

View File

@ -108,6 +108,8 @@ public:
bool use_pdl{false};
bool enable_sm90_mixed_dtype_shuffle_test{false};
//
// Methods
//
@ -160,6 +162,13 @@ public:
/// Buffer used for the cutlass reduction operations' host workspace
std::vector<uint8_t> reduction_host_workspace;
/// For mixed input dtype kernels
DeviceAllocation *Scale{nullptr}; // Scale tensor
DeviceAllocation *Zero{nullptr}; // Zero tensor
DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification
DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle
DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8
cudaStream_t stream;
};

View File

@ -0,0 +1,263 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief GroupedGemm Profiler
*/
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
// CUTLASS Library includes
#include "cutlass/library/library.h"
// Profiler includes
#include "device_context.h"
#include "operation_profiler.h"
#include "options.h"
#include "performance_result.h"
#include "problem_space.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace profiler {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Operation profiler for grouped GEMM
class GroupedGemmOperationProfiler : public OperationProfiler {
public:
/// Problem structure obtained from problem space
struct GroupedGemmProblem {
cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped};
std::vector<gemm::GemmCoord> problem_sizes;
std::vector<cute::Shape<int, int, int>> problem_sizes_3x;
int cluster_m{1};
int cluster_n{1};
int cluster_k{1};
int cluster_m_fallback{1};
int cluster_n_fallback{1};
int cluster_k_fallback{1};
std::vector<int64_t> lda{0};
std::vector<int64_t> ldb{0};
std::vector<int64_t> ldc{0};
std::vector<uint8_t> alpha;
std::vector<uint8_t> beta;
/// Parses the problem
Status parse(
library::GemmDescription const& operation_desc,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem);
int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); };
int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); };
int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); };
/// Total number of bytes loaded
int64_t bytes(library::GemmDescription const& operation_desc) const;
/// Total number of flops computed
int64_t flops(library::GemmDescription const& operation_desc) const;
/// Initializes a performance result
void initialize_result(
PerformanceResult& result,
library::GemmDescription const& operation_desc,
ProblemSpace const& problem_space);
};
// The workspace owns the allocated blocks; the arguments hold only the raw
// pointers
struct GroupedGemmWorkspace {
std::vector<DeviceAllocation*> A_ptr_array_device;
std::vector<DeviceAllocation*> B_ptr_array_device;
std::vector<DeviceAllocation*> C_ptr_array_device;
std::vector<DeviceAllocation*> D_ptr_array_device;
std::vector<DeviceAllocation*> reference_ptr_array_host;
std::vector<DeviceAllocation*> A_ptr_array_host;
std::vector<DeviceAllocation*> B_ptr_array_host;
std::vector<DeviceAllocation*> C_ptr_array_host;
std::vector<DeviceAllocation*> D_ptr_array_host;
/// Number of copies of the problem workspace which are visited sequentially during
/// profiling to avoid camping in the last level cache.
/// *NOT* the number of groups in the grouped GEMM
int problem_count{1};
DeviceAllocation* problem_sizes_array_device{nullptr};
DeviceAllocation* problem_sizes_3x_array_device{nullptr};
DeviceAllocation* lda_array_device{nullptr};
DeviceAllocation* ldb_array_device{nullptr};
DeviceAllocation* ldc_array_device{nullptr};
DeviceAllocation* ldd_array_device{nullptr};
library::GemmGroupedConfiguration configuration;
library::GemmGroupedArguments arguments;
std::vector<uint8_t> host_workspace;
DeviceAllocation device_workspace;
};
private:
void init_arguments(Options const& options) {
gemm_workspace_.arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data();
gemm_workspace_.arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data();
gemm_workspace_.arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data();
gemm_workspace_.arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data();
gemm_workspace_.arguments.alpha = problem_.alpha.data();
gemm_workspace_.arguments.beta = problem_.beta.data();
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
gemm_workspace_.arguments.lda = static_cast<int64_t*>(gemm_workspace_.lda_array_device->data());
gemm_workspace_.arguments.ldb = static_cast<int64_t*>(gemm_workspace_.ldb_array_device->data());
gemm_workspace_.arguments.ldc = static_cast<int64_t*>(gemm_workspace_.ldc_array_device->data());
gemm_workspace_.arguments.ldd = static_cast<int64_t*>(gemm_workspace_.ldc_array_device->data());
gemm_workspace_.arguments.problem_sizes =
static_cast<gemm::GemmCoord*>(gemm_workspace_.problem_sizes_array_device->data());
gemm_workspace_.arguments.problem_sizes_3x = static_cast<cute::Shape<int, int, int>*>(
gemm_workspace_.problem_sizes_3x_array_device->data());
gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data();
gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size();
gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)};
gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)};
/* Query device SM count to pass onto the kernel as an argument, where needed */
gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount;
}
protected:
/// GEMM problem obtained from problem space
GroupedGemmProblem problem_;
/// Device memory allocations
GroupedGemmWorkspace gemm_workspace_;
public:
GroupedGemmOperationProfiler(Options const& options);
virtual ~GroupedGemmOperationProfiler();
GroupedGemmProblem const& problem() const { return problem_; }
/// Prints usage statement for the math function
virtual void print_usage(std::ostream& out) const;
/// Prints examples
virtual void print_examples(std::ostream& out) const;
/// Extracts the problem dimensions
virtual Status initialize_configuration(
Options const& options,
PerformanceReport& report,
DeviceContext& device_context,
library::Operation const* operation,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem);
/// Initializes workspace
virtual Status initialize_workspace(
Options const& options,
PerformanceReport& report,
DeviceContext& device_context,
library::Operation const* operation,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem);
/// Verifies CUTLASS against references
virtual bool verify_cutlass(
Options const& options,
PerformanceReport& report,
DeviceContext& device_context,
library::Operation const* operation,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem);
/// Measures performance results
virtual bool profile(
Options const& options,
PerformanceReport& report,
DeviceContext& device_context,
library::Operation const* operation,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem);
protected:
/// Initializes the performance result
void initialize_result_(
PerformanceResult& result,
Options const& options,
library::GemmDescription const& operation_desc,
ProblemSpace const& problem_space);
/// Verifies CUTLASS against host and device references
bool verify_with_reference_(
Options const& options,
PerformanceReport& report,
DeviceContext& device_context,
library::Operation const* operation,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem,
cutlass::library::NumericTypeID element_A,
cutlass::library::NumericTypeID element_B);
/// Method to profile a CUTLASS Operation
Status profile_cutlass_(
PerformanceResult& result,
Options const& options,
library::Operation const* operation,
void* arguments,
void* host_workspace,
void* device_workspace) override;
/// Initialize reduction problem dimensions and library::Operation
bool initialize_reduction_configuration_(
library::Operation const* operation,
ProblemSpace::Problem const& problem);
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -241,17 +241,30 @@ protected:
/// Profiles the GPU kernel launched in `func` running simultaneously on all
/// requested devices.
Status profile_kernel_w_cuda_graphs_(
PerformanceResult& result,
Options const& options,
std::function<Status(int, cudaStream_t, int)> const& func,
std::vector<cudaStream_t> const& streams);
Status profile_kernel_(
PerformanceResult& result,
Options const& options,
std::function<Status(int, cudaStream_t, int)> const& func,
std::vector<cudaStream_t> const& streams);
/// Profiles the GPU kernel launched in `func` on the `stream`
Status profile_kernel_(
PerformanceResult& result,
Options const& options,
std::function<Status(cudaStream_t, int)> const& func,
cudaStream_t stream = nullptr);
/// Profiles the GPU kernel launched in `func` on the `stream`
Status profile_kernel_no_cuda_graphs_(
PerformanceResult& result,
Options const& options,
std::function<Status(cudaStream_t, int)> const& func,
cudaStream_t stream = nullptr);
private:

View File

@ -208,6 +208,8 @@ public:
/// Minimum number of iterations to profile
int min_iterations{10};
bool use_cuda_graphs{false};
/// Number of milliseconds to sleep between profiling periods
int sleep_duration{50};

View File

@ -988,6 +988,12 @@ bool arg_as_scalar(
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
bool arg_as_string(
std::string& arg,
char const* name,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem);
/// Returns true if a tensor description satisfies a `tensor` value
bool tensor_description_satisfies(
library::TensorDescription const &tensor_desc,

View File

@ -75,7 +75,6 @@ BlockScaledGemmOperationProfiler::BlockScaledGemmOperationProfiler(Options const
{ArgumentTypeID::kTensor, {"D"}, "Tensor storing the D output"},
{ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"},
{ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"},
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
@ -113,14 +112,11 @@ void BlockScaledGemmOperationProfiler::print_examples(std::ostream &out) const {
<< "Schmoo over problem size and beta:\n"
<< " $ cutlass_profiler --operation=block_scaled_gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n"
// TODO: Bring these back once SM100 future audits are complete
#if 0
<< "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n"
<< "For column major, use column, col, or n. For row major use, row or t:\n"
<< " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n"
<< "Profile a particular problem size with split K and parallel reduction:\n"
<< " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n"
#endif
<< "Using various input value distribution:\n"
<< " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n"
@ -225,7 +221,6 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse(
this->split_k_slices = 1;
}
// TODO: Bring these back once SM100 future audits are complete
if (this->split_k_mode != library::SplitKMode::kSerial) {
std::cout<<"SplitK/StreamK feature is not supported yet!";
return Status::kErrorInvalidProblem;
@ -403,7 +398,6 @@ void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
// TODO: Bring these back once SM100 future audits are complete
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
set_argument(result, "split_k_slices", problem_space, split_k_slices);
set_argument(result, "batch_count", problem_space, batch_count);
@ -536,8 +530,6 @@ bool BlockScaledGemmOperationProfiler::initialize_reduction_configuration_(
library::Operation const *operation,
ProblemSpace::Problem const &problem) {
// TODO: Bring these back once SM100 future audits are complete
#if 1
library::BlockScaledGemmDescription const &gemm_desc =
static_cast<library::BlockScaledGemmDescription const&>(operation->description());
@ -577,8 +569,6 @@ bool BlockScaledGemmOperationProfiler::initialize_reduction_configuration_(
// reduction operation found and initialized
return true;
#endif
return false;
}
/// Initializes workspace

View File

@ -545,6 +545,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle,
cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, requestedAlgoCount, heuristicResult, &returnedResults);
if (returnedResults == 0) {
cudaFree(workspaceHeuristic);
return false;
}
@ -589,6 +590,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle,
// Handle errors
if (status != CUBLAS_STATUS_SUCCESS) {
std::cerr << "cublasLtMatmul AutoTuning failed with status: " << cublasLtGetStatusName(status) << std::endl;
cudaFree(workspaceHeuristic);
return false;
}
@ -653,6 +655,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle,
throw std::bad_alloc();
}
cudaFree(workspaceHeuristic);
return true;
}

View File

@ -36,16 +36,17 @@
#include <stdexcept>
// Profiler includes
#include "cutlass/profiler/cutlass_profiler.h"
#include "cutlass/profiler/gemm_operation_profiler.h"
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
#include "cutlass/profiler/rank_k_operation_profiler.h"
#include "cutlass/profiler/rank_2k_operation_profiler.h"
#include "cutlass/profiler/trmm_operation_profiler.h"
#include "cutlass/profiler/symm_operation_profiler.h"
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
#include "cutlass/profiler/conv2d_operation_profiler.h"
#include "cutlass/profiler/conv3d_operation_profiler.h"
#include "cutlass/profiler/cutlass_profiler.h"
#include "cutlass/profiler/gemm_operation_profiler.h"
#include "cutlass/profiler/grouped_gemm_operation_profiler.h"
#include "cutlass/profiler/rank_2k_operation_profiler.h"
#include "cutlass/profiler/rank_k_operation_profiler.h"
#include "cutlass/profiler/sparse_gemm_operation_profiler.h"
#include "cutlass/profiler/symm_operation_profiler.h"
#include "cutlass/profiler/trmm_operation_profiler.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -76,6 +77,8 @@ CutlassProfiler::CutlassProfiler(
operation_profilers_.emplace_back(new TrmmOperationProfiler(options));
operation_profilers_.emplace_back(new SymmOperationProfiler(options));
operation_profilers_.emplace_back(new GroupedGemmOperationProfiler(options));
}
CutlassProfiler::~CutlassProfiler() {
@ -201,6 +204,7 @@ void CutlassProfiler::print_usage_(std::ostream &out) {
<< " $ cutlass_profiler --operation=Conv3d --help\n\n"
<< " $ cutlass_profiler --operation=Conv2d --help\n\n"
<< " $ cutlass_profiler --operation=SparseGemm --help\n\n"
<< " $ cutlass_profiler --operation=GroupedGemm --help\n\n"
;
}

View File

@ -616,7 +616,6 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kFUE4M3:
cutlass::reference::device::BlockFillRandom<cutlass::float_ue4m3_t>(
reinterpret_cast<cutlass::float_ue4m3_t *>(pointer_),
@ -657,7 +656,6 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kF64:
cutlass::reference::device::BlockFillRandom<double>(
reinterpret_cast<double *>(pointer_),
@ -823,7 +821,6 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::host::BlockFillRandom<cutlass::float_e2m3_t>(
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
@ -856,7 +853,6 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kF16:
cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(host_data.data()),
@ -1086,7 +1082,6 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) {
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::device::BlockFillSequential<cutlass::float_e2m3_t>(
reinterpret_cast<cutlass::float_e2m3_t *>(pointer_),
@ -1119,7 +1114,6 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) {
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kF16:
cutlass::reference::device::BlockFillSequential<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(pointer_),
@ -1360,7 +1354,6 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) {
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::host::BlockFillSequential<cutlass::float_e2m3_t>(
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
@ -1393,7 +1386,6 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) {
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kF16:
cutlass::reference::host::BlockFillSequential<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(host_data.data()),
@ -1690,7 +1682,6 @@ bool DeviceAllocation::block_compare_equal(
reinterpret_cast<float_e5m2_t const *>(ptr_A),
reinterpret_cast<float_e5m2_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kFUE4M3:
return reference::device::BlockCompareEqual<float_ue4m3_t>(
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
@ -1717,7 +1708,6 @@ bool DeviceAllocation::block_compare_equal(
reinterpret_cast<float_e2m1_t const *>(ptr_A),
reinterpret_cast<float_e2m1_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kF16:
return reference::device::BlockCompareEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
@ -1886,7 +1876,6 @@ bool DeviceAllocation::block_compare_relatively_equal(
capacity,
static_cast<float_e5m2_t>(epsilon),
static_cast<float_e5m2_t>(nonzero_floor));
case library::NumericTypeID::kFUE4M3:
return reference::device::BlockCompareRelativelyEqual<float_ue4m3_t>(
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
@ -1925,7 +1914,6 @@ bool DeviceAllocation::block_compare_relatively_equal(
capacity,
static_cast<float_e2m1_t>(epsilon),
static_cast<float_e2m1_t>(nonzero_floor));
case library::NumericTypeID::kF16:
return reference::device::BlockCompareRelativelyEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
@ -2273,7 +2261,6 @@ void DeviceAllocation::write_tensor_csv(
write_tensor_csv_static_type<float_ue4m3_t>(out, *this);
break;
case library::NumericTypeID::kFE2M3:
write_tensor_csv_static_type<float_e2m3_t>(out, *this);
break;
@ -2288,7 +2275,6 @@ void DeviceAllocation::write_tensor_csv(
case library::NumericTypeID::kFUE8M0:
write_tensor_csv_static_type<float_ue8m0_t>(out, *this);
break;
case library::NumericTypeID::kF16:
write_tensor_csv_static_type<half_t>(out, *this);
break;
@ -2475,7 +2461,6 @@ void DeviceAllocation::fill_device(double val = 0.0) {
case library::NumericTypeID::kFE2M1:
tensor_fill<float_e2m1_t>(*this, static_cast<float_e2m1_t>(val));
break;
case library::NumericTypeID::kF16:
tensor_fill<half_t>(*this, static_cast<half_t>(val));
@ -2611,7 +2596,6 @@ void DeviceAllocation::fill_host(double val = 0.0) {
static_cast<float_e2m1_t>(val)
);
break;
case library::NumericTypeID::kFE4M3:
cutlass::reference::host::BlockFill<float_e4m3_t>(

View File

@ -75,6 +75,35 @@ DeviceAllocation *DeviceContext::allocate_tensor(
return allocation;
}
static void initialize_allocation_with_data_distribution(
Options const &options,
int seed_shift,
DeviceAllocation *allocation,
Distribution &data_distribution) {
if (options.initialization.provider == library::Provider::kReferenceDevice) {
if (data_distribution.kind == Distribution::Sequential) {
allocation->initialize_sequential_device(
data_distribution);
}
else {
allocation->initialize_random_device(
options.initialization.seed + seed_shift,
data_distribution);
}
}
else if (options.initialization.provider == library::Provider::kReferenceHost) {
if (data_distribution.kind == Distribution::Sequential) {
allocation->initialize_sequential_host(
data_distribution);
}
else {
allocation->initialize_random_host(
options.initialization.seed + seed_shift,
data_distribution);
}
}
}
/// Allocates memory of a given type, capacity (elements), and name
DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
Options const &options,
@ -122,7 +151,6 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
data_distribution.set_uniform(1, 4, 0);
break;
case library::NumericTypeID::kF16:
data_distribution.set_uniform(-3, 3, 0);
break;
@ -168,28 +196,9 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
}
}
if (options.initialization.provider == library::Provider::kReferenceDevice) {
if (data_distribution.kind == Distribution::Sequential) {
allocation->initialize_sequential_device(
data_distribution);
}
else {
allocation->initialize_random_device(
options.initialization.seed + seed_shift,
data_distribution);
}
}
else if (options.initialization.provider == library::Provider::kReferenceHost) {
if (data_distribution.kind == Distribution::Sequential) {
allocation->initialize_sequential_host(
data_distribution);
}
else {
allocation->initialize_random_host(
options.initialization.seed + seed_shift,
data_distribution);
}
}
initialize_allocation_with_data_distribution(
options, seed_shift, allocation, data_distribution
);
}
return allocation;

View File

@ -79,6 +79,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
{ArgumentTypeID::kInteger, {"use_pdl", "use-pdl"}, "Use PDL (true, false)"},
{ArgumentTypeID::kEnumerated, {"enable_sm90_mixed_dtype_shuffle_test", "enable-sm90-mixed-dtype-shuffle-test"}, "Enable SM90 mixed input data type kernel shuffle layout test (true, false)"},
{ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"},
},
{ library::Provider::kCUBLAS}
@ -211,6 +212,11 @@ Status GemmOperationProfiler::GemmProblem::parse(
this->use_pdl = false;
}
if (!arg_as_bool(this->enable_sm90_mixed_dtype_shuffle_test, "enable_sm90_mixed_dtype_shuffle_test", problem_space, problem)) {
// default value
this->enable_sm90_mixed_dtype_shuffle_test = false;
}
if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
// default value
this->split_k_mode = library::SplitKMode::kSerial;
@ -399,6 +405,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
set_argument(result, "swizzle_size", problem_space, swizzle_size);
set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl));
set_argument(result, "enable_sm90_mixed_dtype_shuffle_test", problem_space, library::to_string(enable_sm90_mixed_dtype_shuffle_test));
set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a));
@ -432,14 +439,26 @@ Status GemmOperationProfiler::initialize_configuration(
Status status = problem_.parse(operation_desc, problem_space, problem);
// Note: this is a temporary workaround
bool is_current_operation_sm90_mixed_dtype_shuffle = (strstr(operation_desc.name, "_shfl") != NULL);
if (is_current_operation_sm90_mixed_dtype_shuffle && (problem_.enable_sm90_mixed_dtype_shuffle_test == false)) {
return Status::kErrorInvalidProblem;
}
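// Example (hypothetical flag values): opt in to profiling the shuffled-layout kernels with
//   $ cutlass_profiler --operation=Gemm --enable-sm90-mixed-dtype-shuffle-test=true --kernels=*_shfl*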
if (status != Status::kSuccess) {
return status;
}
const auto device_count = options.device.devices.size();
auto const device_count = options.device.devices.size();
gemm_workspace_.clear();
library::NumericTypeID a_elem = library::get_real_type(operation_desc.A.element);
library::NumericTypeID b_elem = library::get_real_type(operation_desc.B.element);
int a_elem_bits = library::sizeof_bits(a_elem);
int b_elem_bits = library::sizeof_bits(b_elem);
bool is_mixed_input = (a_elem_bits != b_elem_bits);
for (size_t i = 0; i < device_count; ++i) {
cudaSetDevice(options.device.device_id(i));
gemm_workspace_.emplace_back();
@ -455,7 +474,6 @@ Status GemmOperationProfiler::initialize_configuration(
gemm_workspace_[i].configuration.cluster_shape_fallback.m() = int(problem_.cluster_m_fallback);
gemm_workspace_[i].configuration.cluster_shape_fallback.n() = int(problem_.cluster_n_fallback);
gemm_workspace_[i].configuration.cluster_shape_fallback.k() = int(problem_.cluster_k_fallback);
gemm_workspace_[i].configuration.lda = problem_.lda;
gemm_workspace_[i].configuration.ldb = problem_.ldb;
gemm_workspace_[i].configuration.ldc = problem_.ldc;
@ -501,7 +519,77 @@ Status GemmOperationProfiler::initialize_configuration(
initialize_result_(this->model_result_, options, operation_desc, problem_space);
if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) {
if (is_mixed_input)
{
const int options_g = problem_.k;
const int options_l = problem_.batch_count;
const int scale_k = (problem_.k + options_g - 1) / options_g;
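// e.g. with problem_.k = 4096 and options_g = 4096 (a single group spanning
// all of K), scale_k = ceil(4096 / 4096) = 1: one scale and zero value per
// MN-row of the narrow operand.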
// We cannot get the mainloop's ElementScale and ElementZero here, so use the
// wide type to allocate a large enough workspace for the scale (S) and zero (Z) tensors.
library::NumericTypeID wide_dtype;
size_t SZ_mat_size = 0;
if (a_elem_bits > b_elem_bits) {
wide_dtype = a_elem;
SZ_mat_size = static_cast<size_t>(problem_.n * scale_k);
}
else {
wide_dtype = b_elem;
SZ_mat_size = static_cast<size_t>(problem_.m * scale_k);
}
gemm_workspace_[i].Scale = device_context.allocate_tensor(
options,
"Scale",
wide_dtype,
library::LayoutTypeID::kRowMajor,
{int(SZ_mat_size), int(options_l)},
{int(options_l)},
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
gemm_workspace_[i].Zero = device_context.allocate_tensor(
options,
"Zero",
wide_dtype,
library::LayoutTypeID::kRowMajor,
{int(SZ_mat_size), int(options_l)},
{int(options_l)},
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
// Packed scale is for int4 * fp8, where the original scale is fp8 and each
// scale element is packed into an Array<fp8, 8>, which is 64 bits wide
gemm_workspace_[i].packed_Scale = device_context.allocate_tensor(
options,
"packed-Scale",
library::NumericTypeID::kU64,
library::LayoutTypeID::kRowMajor,
{int(SZ_mat_size), int(options_l)},
{int(options_l)},
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)};
gemm_workspace_[i].arguments.batch_count = problem_.batch_count;
// This is the first touch of the arguments: mark the kernel as mixed dtype so that
// the scale and zero tensors are populated in the can_implement() call below.
// A and B are not populated at this point, so do not update the dequantized A or B
gemm_workspace_[i].arguments.is_mixed_dtype = true;
gemm_workspace_[i].arguments.wider_operand = (a_elem_bits > b_elem_bits) ? cutlass::library::Sm90MixedInputWiderOperand::A : cutlass::library::Sm90MixedInputWiderOperand::B;
gemm_workspace_[i].arguments.generate_scale_and_zero = true;
gemm_workspace_[i].arguments.generate_dequantized_AB = false;
gemm_workspace_[i].arguments.dequantized_AB_ready = (bool *) malloc(sizeof(bool));
gemm_workspace_[i].arguments.dequantized_AB_ready[0] = false;
gemm_workspace_[i].arguments.Scale = gemm_workspace_[i].Scale->data();
gemm_workspace_[i].arguments.Zero = gemm_workspace_[i].Zero->data();
gemm_workspace_[i].arguments.packed_Scale = gemm_workspace_[i].packed_Scale->data();
} // End of "if (is_mixed_input)"
const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments);
if (can_implement != Status::kSuccess) {
return can_implement;
}
}
@ -693,6 +781,56 @@ Status GemmOperationProfiler::initialize_workspace(
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
if (gemm_workspace_[i].arguments.is_mixed_dtype) {
// The dequantized tensor has the same shape as the narrow data type tensor,
// and the same data type as the wide data type tensor.
// The encoded tensor has the same shape and data type as the narrow data type tensor.
if (gemm_workspace_[i].arguments.wider_operand == cutlass::library::Sm90MixedInputWiderOperand::A) {
gemm_workspace_[i].dequantized_AB = device_context.allocate_tensor(
options,
"dequantized-B",
operation_desc.A.element,
operation_desc.B.layout,
{int(problem_.k), int(problem_.n)},
{int(problem_.ldb)},
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
gemm_workspace_[i].encoded_AB = device_context.allocate_tensor(
options,
"encoded-B",
operation_desc.B.element,
operation_desc.B.layout,
{int(problem_.k), int(problem_.n)},
{int(problem_.ldb)},
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
}
else {
gemm_workspace_[i].dequantized_AB = device_context.allocate_tensor(
options,
"dequantized-A",
operation_desc.B.element,
operation_desc.A.layout,
{int(problem_.m), int(problem_.k)},
{int(problem_.lda)},
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
gemm_workspace_[i].encoded_AB = device_context.allocate_tensor(
options,
"encoded-A",
operation_desc.A.element,
operation_desc.A.layout,
{int(problem_.m), int(problem_.k)},
{int(problem_.lda)},
problem_.batch_count * gemm_workspace_[i].problem_count,
i // device_index
);
}
}
}
if (options.execution_mode != ExecutionMode::kDryRun) {
@ -712,7 +850,7 @@ Status GemmOperationProfiler::initialize_workspace(
gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride();
/* Query device SM count to pass onto the kernel as an argument, where needed */
gemm_workspace_[i].arguments.sm_count = options.device.properties[0].multiProcessorCount;
gemm_workspace_[i].arguments.sm_count = options.device.properties[i].multiProcessorCount;
gemm_workspace_[i].arguments.device_index = static_cast<int>(i);
}
}
@ -836,6 +974,17 @@ bool GemmOperationProfiler::verify_cutlass(
gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride();
gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride();
if (gemm_workspace_[i].arguments.is_mixed_dtype) {
// Scale and zero were already generated in initialize_configuration(), and
// A and B in initialize_workspace(). Signal
// GemmUniversal3xOperation::update_arguments_() (triggered by underlying_operation->run())
// to generate the dequantized matrix for verification
gemm_workspace_[i].arguments.generate_scale_and_zero = false;
gemm_workspace_[i].arguments.generate_dequantized_AB = true;
gemm_workspace_[i].arguments.dequantized_AB = gemm_workspace_[i].dequantized_AB->data();
gemm_workspace_[i].arguments.encoded_AB = gemm_workspace_[i].encoded_AB->data();
}
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data();
gemm_workspace_[i].arguments.alpha = problem_.alpha_one.data();
@ -1133,7 +1282,6 @@ bool GemmOperationProfiler::verify_with_reference_(
//
// Initialize state
//
for (auto provider : options.verification.providers) {
// Skip providers that are not enabled
@ -1149,6 +1297,21 @@ bool GemmOperationProfiler::verify_with_reference_(
void *ptr_C = gemm_workspace_[i].C->data();
void *ptr_D = gemm_workspace_[i].Reference->data();
cutlass::library::NumericTypeID element_A_for_reference = element_A;
cutlass::library::NumericTypeID element_B_for_reference = element_B;
if (gemm_workspace_[i].arguments.is_mixed_dtype && gemm_workspace_[i].arguments.dequantized_AB_ready[0]) {
// The dequantized tensor has the same shape as the narrow data type tensor,
// and the same data type as the wide data type tensor
if (gemm_workspace_[i].arguments.wider_operand == cutlass::library::Sm90MixedInputWiderOperand::A) {
ptr_B = gemm_workspace_[i].dequantized_AB->data();
element_B_for_reference = element_A;
}
else {
ptr_A = gemm_workspace_[i].dequantized_AB->data();
element_A_for_reference = element_B;
}
}
// To support the host-side reference, conditionally allocate and
// copy tensors to host memory.
std::vector<uint8_t> host_data_A;
@ -1200,13 +1363,13 @@ bool GemmOperationProfiler::verify_with_reference_(
problem_.alpha.data(),
element_A,
element_A_for_reference,
gemm_desc.A.layout,
gemm_desc.transform_A,
ptr_A,
int(gemm_workspace_[i].configuration.lda),
element_B,
element_B_for_reference,
gemm_desc.B.layout,
gemm_desc.transform_B,
ptr_B,
@ -1349,6 +1512,13 @@ Status GemmOperationProfiler::profile_cutlass_(
gemm_workspace_[dev_id].arguments.C = gemm_workspace_[dev_id].C->batch_data(problem_idx);
gemm_workspace_[dev_id].arguments.D = gemm_workspace_[dev_id].Computed->batch_data(problem_idx);
if (gemm_workspace_[dev_id].arguments.is_mixed_dtype) {
// Scale, zero, and dequantized tensors were already generated in
// verify_cutlass(); there is no need to re-generate them while profiling
gemm_workspace_[dev_id].arguments.generate_scale_and_zero = false;
gemm_workspace_[dev_id].arguments.generate_dequantized_AB = false;
}
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_[dev_id].arguments.D = gemm_workspace_[dev_id].device_workspace.data();
@ -1383,11 +1553,6 @@ Status GemmOperationProfiler::profile_cutlass_(
return Status::kSuccess;
};
if (options.device.devices.size() == 1) {
auto func = [&](cudaStream_t stream, int iteration) { return launch_gemm(0, stream, iteration); };
return profile_kernel_(result, options, func, gemm_workspace_[0].stream);
}
std::vector<cudaStream_t> streams(gemm_workspace_.size());
for (size_t i = 0; i < streams.size(); i++) {
streams[i] = gemm_workspace_[i].stream;

File diff suppressed because it is too large

View File

@ -57,16 +57,6 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
#define CUDA_CHECK(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " code=" << err << " \"" \
<< cudaGetErrorString(err) << "\"\n"; \
return Status::kErrorInternal; \
} \
} while (0)
namespace cutlass {
namespace profiler {
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -304,42 +294,43 @@ std::ostream& operator<<(std::ostream& out, library::Provider provider) {
return out;
}
std::ostream& operator<<(std::ostream& out, library::OperationKind provider) {
if (provider == library::OperationKind::kGemm) {
std::ostream& operator<<(std::ostream& out, library::OperationKind op_kind) {
if (op_kind == library::OperationKind::kGemm) {
out << "kGemm";
}
else if (provider == library::OperationKind::kBlockScaledGemm) {
else if (op_kind == library::OperationKind::kBlockScaledGemm) {
out << "kBlockScaledGemm";
}
else if (provider == library::OperationKind::kRankK) {
else if (op_kind == library::OperationKind::kRankK) {
out << "kRankK";
}
else if (provider == library::OperationKind::kRank2K) {
else if (op_kind == library::OperationKind::kRank2K) {
out << "kRank2K";
}
else if (provider == library::OperationKind::kTrmm) {
else if (op_kind == library::OperationKind::kTrmm) {
out << "kTrmm";
}
else if (provider == library::OperationKind::kSymm) {
else if (op_kind == library::OperationKind::kSymm) {
out << "kSymm";
}
else if (provider == library::OperationKind::kConv2d) {
else if (op_kind == library::OperationKind::kConv2d) {
out << "kConv2d";
}
else if (provider == library::OperationKind::kConv3d) {
else if (op_kind == library::OperationKind::kConv3d) {
out << "kConv3d";
}
else if (provider == library::OperationKind::kEqGemm) {
else if (op_kind == library::OperationKind::kEqGemm) {
out << "kEqGemm";
}
else if (provider == library::OperationKind::kSparseGemm) {
else if (op_kind == library::OperationKind::kSparseGemm) {
out << "kSparseGemm";
}
else if (provider == library::OperationKind::kReduction) {
else if (op_kind == library::OperationKind::kReduction) {
out << "kReduction";
}
else if (op_kind == library::OperationKind::kGroupedGemm) {
out << "kGroupedGemm";
}
else {
out << "kInvalid";
}
@ -660,6 +651,11 @@ void OperationProfiler::save_workspace(
DeviceAllocation *allocation = named_allocation.second;
if (allocation->layout() == library::LayoutTypeID::kUnknown) {
continue; // write_tensor not set up to handle DeviceAllocations initialized using
// allocate_block()
}
std::stringstream filename;
filename << desc.name << "_" << library::to_string(provider) << "_";
@ -736,15 +732,20 @@ Status predict_iters(
/// CUDA graphs allow you to record the launch of large numbers of kernels without
/// blocking, and therefore avoid a deadlock that occurs if you try to enqueue too
/// many kernels behind the spin-loop kernel.
Status OperationProfiler::profile_kernel_(
PerformanceResult &result,
Options const &options,
const std::function<Status(int, cudaStream_t, int)> &func,
const std::vector<cudaStream_t> &streams) {
Status OperationProfiler::profile_kernel_w_cuda_graphs_(
PerformanceResult& result,
Options const& options,
std::function<Status(int, cudaStream_t, int)> const& func,
std::vector<cudaStream_t> const& streams) {
auto dev_count = streams.size();
cuda::atomic<bool> *release;
CUDA_CHECK(cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable));
release->store(false, cuda::memory_order_release);
if (dev_count > 1) {
CUDA_CHECK(cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable));
release->store(false, cuda::memory_order_release);
}
std::vector<GpuTimer> timer;
for (size_t i = 0; i < dev_count; ++i) {
@ -774,9 +775,11 @@ Status OperationProfiler::profile_kernel_(
for (size_t i = 0; i < dev_count; ++i) {
CUDA_CHECK(cudaSetDevice(options.device.device_id(i)));
CUDA_CHECK(cudaStreamBeginCapture(streams[i], cudaStreamCaptureModeGlobal));
// Halt execution until all GPUs are ready to proceed.
// This lets the CPU trigger all GPUs to start at the same time.
delay<<<1, 1, 0, streams[i]>>>(release);
if (dev_count > 1) {
// Halt execution until all GPUs are ready to proceed.
// This lets the CPU trigger all GPUs to start at the same time.
delay<<<1, 1, 0, streams[i]>>>(release);
}
for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) {
Status status = func(i, streams[i], iteration);
if (status != Status::kSuccess) {
@ -803,8 +806,10 @@ Status OperationProfiler::profile_kernel_(
CUDA_CHECK(cudaGraphLaunch(graphExecs[i], streams[i]));
}
// release the enqueued kernels
release->store(true, cuda::memory_order_release);
if (dev_count > 1) {
// release the enqueued kernels
release->store(true, cuda::memory_order_release);
}
for (size_t i = 0; i < dev_count; ++i) {
CUDA_CHECK(cudaSetDevice(options.device.device_id(i)));
@ -819,7 +824,9 @@ Status OperationProfiler::profile_kernel_(
}
result.runtime /= static_cast<double>(dev_count);
CUDA_CHECK(cudaFreeHost(release));
if (dev_count > 1) {
CUDA_CHECK(cudaFreeHost(release));
}
for (size_t i = 0; i < dev_count; ++i) {
CUDA_CHECK(cudaSetDevice(options.device.device_id(i)));
@ -835,11 +842,47 @@ Status OperationProfiler::profile_kernel_(
return Status::kSuccess;
}
/// Method to profile GPU execution time of a kernel launched in func
Status OperationProfiler::profile_kernel_(
PerformanceResult &result,
Options const &options,
const std::function<Status(cudaStream_t, int)> &func,
const std::function<Status(int, cudaStream_t, int)> &func,
const std::vector<cudaStream_t> &streams) {
if (options.profiling.use_cuda_graphs) {
return profile_kernel_w_cuda_graphs_(result, options, func, streams);
}
else if (streams.size() == 1) {
auto single_device_func = [&](cudaStream_t stream, int iteration) {
return func(0, stream, iteration);
};
return profile_kernel_no_cuda_graphs_(result, options, single_device_func, streams[0]);
}
return Status::kErrorNotSupported;
}
/// Method to profile GPU execution time of a kernel launched in func
Status OperationProfiler::profile_kernel_(
PerformanceResult& result,
Options const& options,
std::function<Status(cudaStream_t, int)> const& func,
cudaStream_t stream) {
if (options.profiling.use_cuda_graphs) {
auto graph_func = [&](int dev_id, cudaStream_t stream, int iteration) {
return func(stream, iteration);
};
return profile_kernel_w_cuda_graphs_(result, options, graph_func, {stream});
} else {
return profile_kernel_no_cuda_graphs_(result, options, func, stream);
}
}
/// Method to profile GPU execution time of a kernel launched in func
Status OperationProfiler::profile_kernel_no_cuda_graphs_(
PerformanceResult& result,
Options const& options,
std::function<Status(cudaStream_t, int)> const& func,
cudaStream_t stream) {
GpuTimer timer;

View File

@ -477,6 +477,7 @@ Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) {
cmdline.get_cmd_line_argument("profiling-enabled", enabled, true);
cmdline.get_cmd_line_argument("profiling-duration", duration, 10);
cmdline.get_cmd_line_argument("min-iterations", min_iterations, 10);
cmdline.get_cmd_line_argument("use-cuda-graphs", use_cuda_graphs, false);
if (cmdline.check_cmd_line_flag("providers")) {

View File

@ -1203,6 +1203,34 @@ bool arg_as_scalar(
return arg_as_scalar(bytes, numeric_type, value_ptr);
}
/// Returns a copy of the string passed as the argument's value
/// (kScalar arguments are stored as strings).
bool arg_as_string(
std::string& arg,
char const* name,
ProblemSpace const& problem_space,
ProblemSpace::Problem const& problem) {
size_t idx = problem_space.argument_index(name);
KernelArgument::Value const* value_ptr = problem.at(idx).get();
if (value_ptr->not_null) {
if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) {
std::string const& str_value =
static_cast<ScalarArgument::ScalarValue const*>(value_ptr)->value;
arg = std::string(str_value);
}
else {
throw std::runtime_error(
"arg_as_string() - illegal cast. Problem space argument must be scalar");
}
return true;
}
return false;
}
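// Usage sketch (the argument name here is a hypothetical example): read a
// scalar argument as its raw string form instead of parsing it numerically.
//
//   std::string mode;
//   if (arg_as_string(mode, "raster_order", problem_space, problem)) {
//     // "raster_order" was supplied; `mode` now holds its text.
//   }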
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns true if a tensor description satisfies a `tensor` value

View File

@ -56,9 +56,7 @@ template <typename T>
T* allocate(size_t count = 1) {
T* ptr = 0;
size_t bytes = 0;
bytes = count * sizeof(T);
size_t bytes = count * sizeof_bits<T>::value / 8;
cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes);

View File

@ -0,0 +1,480 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Utilities for mixed input data type kernels.
*/
#pragma once
#include <cuda.h>
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/mma_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cute/util/type_traits.hpp"
namespace cutlass {
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
<< " at line: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
template <
class QuantizedElement,
class DequantizedElement,
class OperandLayout,
class ElementScale,
class ElementZero,
class ScaleBroadCastLayout,
class ThrLayout>
__global__ void dequantize_kernel(DequantizedElement* dq_buffer,
QuantizedElement const* q_buffer,
OperandLayout const operand_layout,
ElementScale const* scale_buffer,
ElementZero const* zero_buffer,
ScaleBroadCastLayout const broadcasted_scale_layout,
ThrLayout thr_layout) {
using namespace cute;
// Represent the full tensors to gmem elements.
// These are expected to have shape [MN, K, L]
cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
auto init_quantized_iterator = [&]() {
if constexpr (cute::sizeof_bits_v<QuantizedElement> >= 8) {
return cute::make_gmem_ptr(q_buffer);
}
else {
return cute::subbyte_iterator<const QuantizedElement>(q_buffer);
}
};
cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout);
// The scales are expected to have shape [MN, G, L], with a stride chosen to allow broadcasting.
// It is expected that K % G == 0
cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
// Assign 1 thread per element in the thread block
auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{});
auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
// Tile across the block
auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
// Make a fragment of registers to hold gmem loads
cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
const auto num_iters = cute::size<3>(tOpDq_gOpDq);
for (int ii = 0; ii < num_iters; ++ii) {
const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
if (thread_offset < cute::size<0>(operand_layout)) {
cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{});
cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{});
cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
}
}
}
template <
class QuantizedElement,
class DequantizedElement,
class OperandLayout,
class ElementScale,
class ElementZero,
class ScaleLayout>
static void dequantize(DequantizedElement* dq_buffer,
QuantizedElement const* q_buffer,
OperandLayout const operand_layout,
ElementScale const* scale_buffer,
ElementZero const* zero_buffer,
ScaleLayout const scale_layout,
int const group_size,
cudaStream_t &stream) {
using namespace cute;
constexpr int tpb = 128;
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
const auto num_rows = get<0>(shape(operand_layout));
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
if (num_rows != size<0>(scale_layout)) {
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
<< std::endl;
exit(-1);
}
const auto scale_stride0 = get<0>(stride(scale_layout));
const auto scale_stride1 = get<1>(stride(scale_layout));
const auto scale_stride2 = get<2>(stride(scale_layout));
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
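// With this broadcast layout, operand element k maps to scale element k / group_size:
// the inner (group_size) mode has stride 0, so e.g. with group_size = 128 and
// scale_k = 8, coordinates k = 0..127 all read scale element 0.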
const auto blocks_x = gemm_k;
const auto blocks_y = batches;
dim3 blocks(blocks_x, blocks_y, 1);
dequantize_kernel<<<blocks, tpb, 0, stream>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
CUDA_CHECK(cudaStreamSynchronize(stream));
}
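// Usage sketch (pointers and sizes are hypothetical): dequantize a K-major int4
// operand of shape [MN, K, L] = [4096, 1024, 1] into half_t, with one half_t
// scale and zero per group of g = 128 elements along K.
//
//   int mn = 4096, k = 1024, l = 1, g = 128, scale_k = k / g;
//   auto layout_op = cute::make_layout(cute::make_shape(mn, k, l),
//                                      cute::make_stride(k, 1, mn * k));
//   auto layout_s  = cute::make_layout(cute::make_shape(mn, scale_k, l),
//                                      cute::make_stride(scale_k, 1, mn * scale_k));
//   cudaStream_t stream = nullptr;
//   // d_dq: half_t*, d_q: int4b_t const*, d_scale/d_zero: half_t const* (device)
//   dequantize(d_dq, d_q, layout_op, d_scale, d_zero, layout_s, g, stream);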
template <typename T>
class packed_scale_t {
public:
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
cute::is_same_v<T, cutlass::uint8_t> ||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
cute::is_same_v<T, cutlass::float_e5m2_t>,
"only 8 bit arithmetic types are supported.");
CUTLASS_HOST_DEVICE
explicit packed_scale_t(T val) {
if constexpr (!cute::is_unsigned_v<T>) {
// Only pack negative values. The positive values are generated in flight in the mainloop.
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
}
else {
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
}
}
CUTLASS_HOST_DEVICE
packed_scale_t() = default;
CUTLASS_HOST_DEVICE
explicit operator float() const {
return float(get());
}
CUTLASS_HOST_DEVICE
bool operator==(packed_scale_t const& rhs) const {
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
}
CUTLASS_HOST_DEVICE
bool operator!=(packed_scale_t const& rhs) const {
return !(*this == rhs);
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() + rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() - rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() * rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() / rhs.get());
}
private:
using Storage = uint32_t;
using Stage = uint8_t;
Storage storage[2] {};
CUTLASS_HOST_DEVICE
static Storage pack4(T c1, T c2, T c3, T c4) {
Storage result = 0;
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
return result;
}
CUTLASS_HOST_DEVICE
T get() const {
auto stage = static_cast<Stage>(storage[0] >> 8);
#if defined(__CUDA_ARCH__)
return reinterpret_cast<T const&>(stage);
#else
T tmp;
std::memcpy(&tmp, &stage, sizeof(Stage));
return tmp;
#endif
}
CUTLASS_HOST_DEVICE
T get(int idx) const {
Stage stage;
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
#if defined(__CUDA_ARCH__)
return reinterpret_cast<T const&>(stage);
#else
T tmp;
std::memcpy(&tmp, &stage, sizeof(Stage));
return tmp;
#endif
}
};
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
// Here the encodings of positive values and negative values are unified (except for the sign bit).
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) {
using StorageType = cutlass::int4b_t::Storage;
constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
const size_t host_buf_size = block_size / pack;
std::vector<StorageType> host_buf(host_buf_size);
cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size);
for (auto&& d : host_buf) {
StorageType out = 0;
StorageType mask = 0x0f;
for (int i = 0; i < pack; i++) {
cutlass::int4b_t curr;
curr.storage = (d >> (i * 4)) & 0x0f;
switch (curr) {
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
default: break;
}
out |= (curr.storage << (4 * i)) & mask;
mask <<= 4;
}
d = out;
}
cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size);
return true;
}
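// Usage sketch (device pointer is hypothetical): re-encode an int4 weight block
// in place; the input and output pointers may alias, since the transform
// round-trips through a host staging buffer.
//
//   cutlass::int4b_t* d_B = ...;          // device buffer of int4 weights
//   size_t count = size_t(mn) * k * l;    // element count (multiple of the pack factor)
//   bool ok = unified_encode_int4b(d_B, d_B, count);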
template <class ElementScale>
static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array<ElementScale, 8> *block_out, const size_t block_size) {
std::vector<ElementScale> data_in(block_size);
std::vector<cutlass::Array<ElementScale, 8>> data_out(block_size);
try {
cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size);
}
catch (cutlass::cuda_exception const& e) {
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
for (size_t i = 0; i < block_size; i++) {
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
}
try {
cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size);
}
catch (cutlass::cuda_exception const& e) {
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
return true;
}
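// Usage sketch (pointers are hypothetical): expand each fp8 scale on device into
// the 8-entry packed form consumed by the int4 x fp8 mainloop.
//
//   using ElementScale = cutlass::float_e4m3_t;
//   ElementScale const* d_scale = ...;                   // [MN, scale_k, L] scale tensor
//   cutlass::Array<ElementScale, 8>* d_packed = ...;     // same extent, packed output
//   bool ok = pack_scale_fp8(d_scale, d_packed, scale_count);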
template <class T, class = void>
struct UnderlyingElement {
using type = T;
};
template <class T>
struct UnderlyingElement<T, cute::void_t<typename T::Element>> {
using type = typename T::Element;
};
// Given a type of MMA instruction, compute a memory reordering atom that places all values
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
// In addition, we can reorder the values across several MMA instructions to get even wider
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
// more optimal conversion instruction sequences (ValLayout parameter).
template <class ElementMma,
class AtomLayout = cute::Layout<cute::_1>,
class ValLayout = cute::Layout<cute::_1>>
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
{
using namespace cute;
static_assert(is_static_v<ValLayout>, "ValLayout must be static");
static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
// 1. Choose an MMA atom to access TV layout and MN shape
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
using MmaTraits = MMA_Traits<MmaAtom>;
auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
auto tv_layout_mma = typename MmaTraits::ALayout{};
static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
// 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
// Note: this assumes A is partitioned between warps along M mode
auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
// 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
return layout_atom;
}
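// Usage sketch (layout names are assumptions): build the reorder atom for the MMA
// compute type, tile it to the narrow operand's full shape, and shuffle the tensor
// on device with reorder_tensor() (defined below).
//
//   using MmaType = cutlass::float_e4m3_t;
//   auto layout_atom        = compute_memory_reordering_atom<MmaType>();
//   auto layout_B_reordered = cute::tile_to_shape(layout_atom, cute::shape(layout_B));
//   reorder_tensor(d_B, layout_B, layout_B_reordered);   // in-place shuffle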
template <class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
__global__ void reorder_tensor_kernel(
cute::Tensor<EngineSrc, LayoutSrc> S,
cute::Tensor<EngineDst, LayoutDst> D,
TiledCopy tiled_copy)
{
using namespace cute;
using T = typename EngineDst::value_type;
Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
auto thread_copy = tiled_copy.get_slice(threadIdx.x);
Tensor tS = thread_copy.partition_S(gS);
Tensor tD = thread_copy.partition_D(gD);
copy(tiled_copy, tS, tD);
}
template <class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
void reorder_tensor(
cute::Tensor<EngineSrc, LayoutSrc> S,
cute::Tensor<EngineDst, LayoutDst> D)
{
using namespace cute;
using T = typename EngineDst::value_type;
static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
// Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
// This avoids a race condition when writing out subbyte types (e.g. int4b_t).
auto has_major_mode = [](auto s) {
return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; });
};
static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
"Could not find stride-1 mode in destination layout");
constexpr int N = shape_div(Int<8>{}, sizeof_bits<T>{});
auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
// Make a tiled copy with a simple row-major thread order and above layout
int constexpr NumThreads = 128;
auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
// Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
using TileShape = Shape<_16>;
auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
CUDA_CHECK(cudaDeviceSynchronize());
}
// In-place version
template <class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T const* src,
LayoutSrc const& layout_src,
T * dst,
LayoutDst const& layout_dst)
{
using namespace cute;
reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
make_tensor(make_gmem_ptr<T>(dst), layout_dst));
}
// In-place version
template <class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T * data,
LayoutSrc const& layout_src,
LayoutDst const& layout_dst)
{
using namespace cute;
cutlass::DeviceAllocation<T> temp(size(layout_src));
reorder_tensor(data, layout_src, temp.get(), layout_dst);
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
}
#undef CUDA_CHECK
} // namespace cutlass

View File

@ -388,7 +388,7 @@ void compute_1d_scaling_factor_and_quantized_output(
absolute_value_op<ElementCompute> abs_op;
maximum_with_nan_propogation<ElementCompute> max_op;
if constexpr (cute::is_constant<1, decltype(cute::stride<0,1>(tensor_SfD))>::value) {
if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) {
// MN major output
int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize);
// Col major output
@ -705,7 +705,7 @@ void gett_epilogue(
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
// per-row alpha
// vector alpha
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l));
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
@ -719,7 +719,7 @@ void gett_epilogue(
if (raw_pointer_cast(epilogue_params.C.data())) {
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
// per-row beta
// vector beta
if (epilogue_params.Vbeta.data()) {
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l));
converted_beta = mul(converted_beta, converted_scale_c);