cutlass 3.9 update (#2255)

* cutlass 3.9 update

* rebase

* fixes out of shared memory for blockwise Blackwell

* doc format

* fix issue 2253

* disable host ref by default

* fix sm120 smem capacity

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-04-24 12:42:40 -07:00
committed by GitHub
parent 8e345c5c5b
commit 331a1f5b3f
143 changed files with 18089 additions and 5935 deletions

View File

@ -233,6 +233,10 @@ cutlass_add_cutlass_library(
src/reference/gemm_f6_f8_f32.cu
src/reference/gemm_f8_f4_f32.cu
src/reference/gemm_f8_f6_f32.cu
src/reference/blockwise_gemm_fp8_fp16out.cu
src/reference/blockwise_gemm_fp8_fp32out.cu
src/reference/blockwise_gemm_fp8_bf16out.cu
src/reference/gemm_s8_s8_s32.cu
src/reference/gemm_u8_u8_s32.cu

View File

@ -313,10 +313,16 @@ struct BlockScaleDescription {
TensorDescription SFD;
/// Describes the input ScaleFactor VectorSize
int SFVecSize;
int SFMVecSize;
int SFNVecSize;
int SFKVecSize;
/// Describes the Output ScaleFactor VectorSize
int EpilogueSFVecSize;
/// Describes the underlying kind of scaling:
/// Tensor Core supported (BlockScaled) or manual scaling (Blockwise)
OperationKind kind;
};
struct GroupedGemmDescription : public OperationDescription {
@ -418,6 +424,96 @@ struct BlockScaledGemmDescription : public OperationDescription {
transform_B(transform_B) {}
};
/// Description of all GEMM computations
struct BlockwiseGemmDescription : public OperationDescription {
/// Indicates the kind of GEMM performed
GemmKind gemm_kind;
/// Describes the A operand
TensorDescription A;
/// Describes the B operand
TensorDescription B;
/// Describes the source matrix
TensorDescription C;
/// Describes the destination matrix
TensorDescription D;
/// Describes the SFA operand
TensorDescription SFA;
/// Describes the SFB operand
TensorDescription SFB;
/// Describes the data type of the scalars passed to the epilogue
NumericTypeID element_epilogue;
/// Describes the structure of parallel reductions
SplitKMode split_k_mode;
/// Transformation on A operand
ComplexTransform transform_A;
/// Transformation on B operand
ComplexTransform transform_B;
/// Describes the input ScaleFactor VectorSize
int SFMVecSize;
int SFNVecSize;
int SFKVecSize;
//
// Methods
//
BlockwiseGemmDescription(
GemmKind gemm_kind = GemmKind::kGemm,
TensorDescription const& A = TensorDescription(),
TensorDescription const& B = TensorDescription(),
TensorDescription const& C = TensorDescription(),
TensorDescription const& D = TensorDescription(),
NumericTypeID element_epilogue = NumericTypeID::kInvalid,
SplitKMode split_k_mode = SplitKMode::kNone,
ComplexTransform transform_A = ComplexTransform::kNone,
ComplexTransform transform_B = ComplexTransform::kNone
):
gemm_kind(gemm_kind),
A(A),
B(B),
C(C),
D(D),
element_epilogue(element_epilogue),
split_k_mode(split_k_mode),
transform_A(transform_A),
transform_B(transform_B) {}
BlockwiseGemmDescription(
OperationDescription op_desc,
GemmKind gemm_kind,
TensorDescription const& A,
TensorDescription const& B,
TensorDescription const& C,
TensorDescription const& D,
NumericTypeID element_epilogue,
SplitKMode split_k_mode,
ComplexTransform transform_A,
ComplexTransform transform_B
):
OperationDescription(op_desc),
gemm_kind(gemm_kind),
A(A),
B(B),
C(C),
D(D),
element_epilogue(element_epilogue),
split_k_mode(split_k_mode),
transform_A(transform_A),
transform_B(transform_B) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Description for structured sparse GEMMs.

View File

@ -121,6 +121,13 @@ public:
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const = 0;
// Set arguments that should only be set once before verifying or profiling the kernel.
// This should encompass any expensive operations that don't vary from run to run
// (e.g., max_active_clusters).
virtual Status initialize_with_arguments(void* arguments_ptr) const {
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -389,6 +396,56 @@ struct BlockScaledGemmArguments {
bool use_pdl{false};
};
/// Blockwise GEMM
//
// OperationKind: kBlockwiseGemm
// GemmKind: Universal
struct BlockwiseGemmArguments {
// NOTE: these are replicated for 3.0 interfaces
gemm::GemmCoord problem_size{};
gemm::GemmCoord cluster_shape{};
gemm::GemmCoord cluster_shape_fallback{};
int batch_count{1};
void const *A{nullptr};
void const *B{nullptr};
void const *SFA{nullptr};
void const *SFB{nullptr};
void const *C{nullptr};
void *D{nullptr};
void const *alpha{nullptr};
void const *beta{nullptr};
ScalarPointerMode pointer_mode{};
// NOTE: these are replicated for 3.0 interfaces
int64_t lda{0};
int64_t ldb{0};
int64_t ldc{0};
int64_t ldd{0};
int64_t batch_stride_A{0};
int64_t batch_stride_B{0};
int64_t batch_stride_C{0};
int64_t batch_stride_D{0};
int sf_m_vec_size{0};
int sf_n_vec_size{0};
int sf_k_vec_size{0};
// Needed for some 3.x kernels
int sm_count{0};
library::RasterOrder raster_order{};
int swizzle_size{1};
int split_k_slices{1};
library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic};
library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic};
bool use_pdl{false};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -521,6 +578,8 @@ struct GemmGroupedArguments {
// these should really be in the configuration but staying consistent with GEMM
int sm_count{0};
int max_active_clusters{0};
// The user is responsible for allocating storage for problem sizes.
// Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we
// unfortunately need to have both options in this struct, and the
@ -536,6 +595,12 @@ struct GroupedGemmBlockScaledArguments : GemmGroupedArguments {
void* norm_constant{nullptr};
};
struct GroupedGemmBlockwiseArguments : GemmGroupedArguments {
void* SFA{nullptr};
void* SFB{nullptr};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// OperationKind: kSparseGemm

View File

@ -427,6 +427,183 @@ using BlockScaledGemmOperationFunctionalMap = std::unordered_map<
BlockScaledGemmFunctionalKeyHasher
>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// Data Structures for Blockwise Gemm Functional Maps
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tuple uniquely identifying Gemm functional behavior
struct BlockwiseGemmFunctionalKey {
Provider provider;
GemmKind gemm_kind;
OperationKind kind;
NumericTypeID element_compute;
NumericTypeID element_scalar;
NumericTypeID element_A;
LayoutTypeID layout_A;
NumericTypeID element_SFA;
NumericTypeID element_B;
LayoutTypeID layout_B;
NumericTypeID element_SFB;
NumericTypeID element_C;
LayoutTypeID layout_C;
NumericTypeID element_D;
LayoutTypeID layout_D;
int SFMVecSize;
int SFNVecSize;
int SFKVecSize;
//
// Methods
//
inline
BlockwiseGemmFunctionalKey(
Provider provider,
GemmKind gemm_kind = GemmKind::kGemm,
OperationKind kind = OperationKind::kBlockwiseGemm,
NumericTypeID element_compute = NumericTypeID::kF32,
NumericTypeID element_scalar = NumericTypeID::kF32,
NumericTypeID element_A = NumericTypeID::kF16,
LayoutTypeID layout_A = LayoutTypeID::kColumnMajor,
NumericTypeID element_SFA = NumericTypeID::kF16,
NumericTypeID element_B = NumericTypeID::kF16,
LayoutTypeID layout_B = LayoutTypeID::kColumnMajor,
NumericTypeID element_SFB = NumericTypeID::kF16,
NumericTypeID element_C = NumericTypeID::kF16,
LayoutTypeID layout_C = LayoutTypeID::kColumnMajor,
NumericTypeID element_D = NumericTypeID::kF16,
LayoutTypeID layout_D = LayoutTypeID::kColumnMajor,
int sfm_vec_size = 32,
int sfn_vec_size = 32,
int sfk_vec_size = 32
):
provider(provider),
gemm_kind(gemm_kind),
kind(kind),
element_compute(element_compute),
element_scalar(element_scalar),
element_A(element_A),
layout_A(layout_A),
element_SFA(element_SFA),
element_B(element_B),
layout_B(layout_B),
element_SFB(element_SFB),
element_C(element_C),
layout_C(layout_C),
element_D(element_D),
layout_D(layout_D),
SFMVecSize(sfm_vec_size),
SFNVecSize(sfn_vec_size),
SFKVecSize(sfk_vec_size)
{ }
inline
bool operator==(BlockwiseGemmFunctionalKey const &rhs) const {
return
(provider == rhs.provider) &&
(gemm_kind == rhs.gemm_kind) &&
(kind == rhs.kind) &&
(element_compute == rhs.element_compute) &&
(element_scalar == rhs.element_scalar) &&
(element_A == rhs.element_A) &&
(layout_A == rhs.layout_A) &&
(element_SFA == rhs.element_SFA) &&
(element_B == rhs.element_B) &&
(layout_B == rhs.layout_B) &&
(element_SFB == rhs.element_SFB) &&
(element_C == rhs.element_C) &&
(layout_C == rhs.layout_C) &&
(element_D == rhs.element_D) &&
(layout_D == rhs.layout_D) &&
(SFMVecSize == rhs.SFMVecSize) &&
(SFNVecSize == rhs.SFNVecSize) &&
(SFKVecSize == rhs.SFKVecSize);
}
inline
bool operator!=(BlockwiseGemmFunctionalKey const &rhs) const {
return !(*this == rhs);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
inline
std::ostream & operator<<(std::ostream &out, cutlass::library::BlockwiseGemmFunctionalKey const &k) {
out << "{\n"
<< " provider: " << to_string(k.provider) << "\n"
<< " gemm_kind: " << to_string(k.gemm_kind) << "\n"
<< " kind: " << to_string(k.kind) << "\n"
<< " element_compute: " << to_string(k.element_compute) << "\n"
<< " element_scalar: " << to_string(k.element_scalar) << "\n"
<< " element_A: " << to_string(k.element_A) << "\n"
<< " layout_A: " << to_string(k.layout_A) << "\n"
<< " element_SFA: " << to_string(k.element_SFA) << "\n"
<< " element_B: " << to_string(k.element_B) << "\n"
<< " layout_B: " << to_string(k.layout_B) << "\n"
<< " element_SFB: " << to_string(k.element_SFB) << "\n"
<< " element_C: " << to_string(k.element_C) << "\n"
<< " layout_C: " << to_string(k.layout_C) << "\n"
<< " element_D: " << to_string(k.element_D) << "\n"
<< " layout_D: " << to_string(k.layout_D) << "\n"
<< " SFMVecSize: " << k.SFMVecSize << "\n"
<< " SFNVecSize: " << k.SFNVecSize << "\n"
<< " SFKVecSize: " << k.SFKVecSize << "\n"
<< "}";
return out;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Hash function for BlockwiseGemmFunctionalKeyHasher
struct BlockwiseGemmFunctionalKeyHasher {
using IntHash = std::hash<int>;
inline
static size_t rotl(size_t key, int shl) {
return (key << shl) | (key >> (sizeof(key)*8u - static_cast<size_t>(shl)));
}
inline
size_t operator()(BlockwiseGemmFunctionalKey const &key) const {
IntHash hash;
return
rotl(hash(int(key.provider)), 1) ^
rotl(hash(int(key.gemm_kind)), 2) ^
rotl(hash(int(key.kind)), 3) ^
rotl(hash(int(key.element_compute)), 4) ^
rotl(hash(int(key.element_scalar)), 5) ^
rotl(hash(int(key.element_A)), 6) ^
rotl(hash(int(key.layout_A)), 7) ^
rotl(hash(int(key.element_SFA)), 8) ^
rotl(hash(int(key.element_B)), 9) ^
rotl(hash(int(key.layout_B)), 10) ^
rotl(hash(int(key.element_SFB)), 11) ^
rotl(hash(int(key.element_C)), 12) ^
rotl(hash(int(key.layout_C)), 13) ^
rotl(hash(int(key.element_D)), 14) ^
rotl(hash(int(key.layout_D)), 15) ^
rotl(hash(int(key.SFMVecSize)), 16) ^
rotl(hash(int(key.SFNVecSize)), 17) ^
rotl(hash(int(key.SFKVecSize)), 18)
;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm
using BlockwiseGemmOperationFunctionalMap = std::unordered_map<
BlockwiseGemmFunctionalKey,
GemmOperationVectorMap,
BlockwiseGemmFunctionalKeyHasher
>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Data Structures for Conv Functional Maps
@ -697,6 +874,9 @@ public:
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations;
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
BlockwiseGemmOperationFunctionalMap blockwise_gemm_operations;
/// Map of all operations of type kConv2d
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
ConvOperationFunctionalMap conv2d_operations;

View File

@ -143,6 +143,7 @@ enum class Provider {
enum class OperationKind {
kGemm,
kBlockScaledGemm,
kBlockwiseGemm,
kRankK,
kRank2K,
kTrmm,

View File

@ -0,0 +1,429 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines operations for all GEMM operation kinds in CUTLASS Library.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/library/library.h"
#include "library_internal.h"
#include "gemm_operation_3x.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::library {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
public:
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
using ElementA = typename Operator::CollectiveMainloop::ElementA;
using ElementSFA = typename Operator::ElementAccumulator;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::CollectiveMainloop::ElementB;
using ElementSFB = typename Operator::ElementAccumulator;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementD = typename Operator::ElementD;
using LayoutD = typename Operator::LayoutD;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using TiledMma = typename Operator::CollectiveMainloop::TiledMma;
using CollectiveMainloop = typename Operator::CollectiveMainloop;
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
"ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
private:
BlockwiseGemmDescription description_;
public:
/// Constructor
BlockwiseGemmUniversal3xOperation(char const *name = "unknown_gemm"):
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
description_.kind = OperationKind::kBlockwiseGemm;
description_.SFA.element = NumericTypeMap<ElementSFA>::kId;
description_.SFA.layout = size<0,1>(typename CollectiveMainloop::LayoutSFA{}.stride()) == 1 ?
LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor;
description_.SFA.alignment = CollectiveMainloop::AlignmentSFA;
description_.SFA.log_extent_range = 32;
description_.SFA.log_stride_range = 32;
description_.SFB.element = NumericTypeMap<ElementSFB>::kId;
description_.SFB.layout = size<0,1>(typename CollectiveMainloop::LayoutSFB{}.stride()) == 1 ?
LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor;
description_.SFB.alignment = CollectiveMainloop::AlignmentSFA;
description_.SFB.log_extent_range = 32;
description_.SFB.log_stride_range = 32;
description_.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM;
description_.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN;
description_.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK;
description_.name = name;
description_.provider = Provider::kCUTLASS;
description_.gemm_kind = GemmKind::kUniversal;
description_.tile_description.threadblock_shape = make_Coord(
Operator::ThreadblockShape::kM,
Operator::ThreadblockShape::kN,
Operator::ThreadblockShape::kK);
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
description_.tile_description.cluster_shape = make_Coord(
Operator::ClusterShape::kM,
Operator::ClusterShape::kN,
Operator::ClusterShape::kK);
}
description_.tile_description.threadblock_stages = Operator::kStages;
description_.tile_description.warp_count = make_Coord(
Operator::WarpCount::kM,
Operator::WarpCount::kN,
Operator::WarpCount::kK);
description_.tile_description.math_instruction.instruction_shape = make_Coord(
Operator::InstructionShape::kM,
Operator::InstructionShape::kN,
Operator::InstructionShape::kK);
description_.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
description_.tile_description.math_instruction.opcode_class =
OpcodeClassMap<typename Operator::OperatorClass>::kId;
description_.tile_description.math_instruction.math_operation =
MathOperationMap<typename Operator::MathOperator>::kId;
description_.tile_description.minimum_compute_capability =
ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMin;
description_.tile_description.maximum_compute_capability =
ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMax;
description_.A = make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
description_.B = make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
description_.C = make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
description_.D = make_TensorDescription<ElementD, LayoutD>(Operator::kAlignmentD);
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.split_k_mode = SplitKMode::kNone;
}
/// Returns the description of the GEMM operation
virtual OperationDescription const & description() const {
return description_;
}
/// Returns the description of the GEMM operation
BlockwiseGemmDescription const& get_gemm_description() const {
return description_;
}
protected:
/// Constructs the arguments structure given the configuration and arguments
static Status construct_arguments_(
OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
// NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
// Do nothing here and construct kernel arguments in update_arguments_ instead
// We also cannot construct TMA descriptors without all the arguments available
operator_args.mode = configuration->mode;
return Status::kSuccess;
}
template<class FusionArgs, class = void>
struct UpdateFusionArgs {
static Status update_(FusionArgs const& fusion_args, BlockwiseGemmArguments const &arguments) {
// If a custom EVT is instantiated then it is the users's responsibility
// to ensure alpha and beta are updated appropriately
return Status::kSuccess;
}
};
template<class FusionArgs>
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
static Status update_(FusionArgs& fusion_args, BlockwiseGemmArguments const &arguments) {
if (arguments.pointer_mode == ScalarPointerMode::kHost) {
fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
fusion_args.beta = *static_cast<ElementCompute const *>(arguments.beta);
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
return Status::kSuccess;
}
else if (arguments.pointer_mode == ScalarPointerMode::kDevice) {
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = static_cast<ElementCompute const *>(arguments.alpha);
fusion_args.beta_ptr = static_cast<ElementCompute const *>(arguments.beta);
return Status::kSuccess;
}
else {
return Status::kErrorInvalidProblem;
}
}
};
/// Constructs the arguments structure given the configuration and arguments
static Status update_arguments_(
OperatorArguments &operator_args,
BlockwiseGemmArguments const *arguments) {
Status status = Status::kSuccess;
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
operator_args.epilogue.thread, *arguments);
if (status != Status::kSuccess) {
return status;
}
operator_args.problem_shape = cute::make_shape(
arguments->problem_size.m(),
arguments->problem_size.n(),
arguments->problem_size.k(),
arguments->batch_count);
// update arguments
if constexpr (IsRuntimeDataType) {
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);
std::unordered_map<RuntimeDatatype, cute::UMMA::MXF8F6F4Format> mapping = {
{RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
{RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
{RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
{RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
};
auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);
if (iter_runtime_a != mapping.end()) {
operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
} else {
assert("invalid runtime argument for datatype A!");
}
if (iter_runtime_b != mapping.end()) {
operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
} else {
assert("invalid runtime argument for datatype B!");
}
}
else {
operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
}
operator_args.mainloop.ptr_SFA = static_cast<ElementSFA const *>(arguments->SFA);
operator_args.mainloop.ptr_SFB = static_cast<ElementSFB const *>(arguments->SFB);
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
arguments->lda, arguments->batch_stride_A);
operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
arguments->ldb, arguments->batch_stride_B);
operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
arguments->ldc, arguments->batch_stride_C);
operator_args.epilogue.dD = operator_args.epilogue.dC;
operator_args.mainloop.layout_SFA = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(operator_args.problem_shape);
operator_args.mainloop.layout_SFB = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(operator_args.problem_shape);
/* Query device SM count to pass onto the kernel as an argument, where needed */
operator_args.hw_info.sm_count = arguments->sm_count;
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
operator_args.scheduler.max_swizzle_size = arguments->swizzle_size;
}
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.raster_order)>) {
using Enum_t = decltype(operator_args.scheduler.raster_order);
switch (arguments->raster_order) {
case RasterOrder::kAlongN:
operator_args.scheduler.raster_order = Enum_t::AlongN;
break;
case RasterOrder::kAlongM:
operator_args.scheduler.raster_order = Enum_t::AlongM;
break;
default:
operator_args.scheduler.raster_order = Enum_t::Heuristic;
}
}
if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
operator_args.scheduler.splits = arguments->split_k_slices;
}
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
operator_args.hw_info.cluster_shape = dim3(
arguments->cluster_shape.m(),
arguments->cluster_shape.n(),
arguments->cluster_shape.k());
operator_args.hw_info.cluster_shape_fallback = dim3(
arguments->cluster_shape_fallback.m(),
arguments->cluster_shape_fallback.n(),
arguments->cluster_shape_fallback.k());
}
return status;
}
public:
/// Returns success if the operation can proceed
Status can_implement(
void const *configuration_ptr, void const *arguments_ptr) const override {
GemmUniversalConfiguration const *configuration =
static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
BlockwiseGemmArguments const *arguments =
static_cast<BlockwiseGemmArguments const *>(arguments_ptr);
if (arguments->sf_m_vec_size != description_.SFMVecSize && arguments->sf_m_vec_size != 0) {
return Status::kErrorInvalidProblem;
}
if (arguments->sf_n_vec_size != description_.SFNVecSize && arguments->sf_n_vec_size != 0) {
return Status::kErrorInvalidProblem;
}
if (arguments->sf_k_vec_size != description_.SFKVecSize && arguments->sf_k_vec_size != 0) {
return Status::kErrorInvalidProblem;
}
OperatorArguments args;
auto status = update_arguments_(args, arguments);
if (status != Status::kSuccess) {
return status;
}
// can_implement rules may need access to problem shape
args.problem_shape = cute::make_shape(
configuration->problem_size.m(),
configuration->problem_size.n(),
configuration->problem_size.k(),
configuration->batch_count);
return Operator::can_implement(args);
}
/// Gets the host-side workspace
uint64_t get_host_workspace_size(void const *configuration) const override {
return sizeof(Operator);
}
/// Gets the device-side workspace
uint64_t get_device_workspace_size(
void const *configuration_ptr,void const *arguments_ptr) const override {
OperatorArguments args;
auto status = update_arguments_(
args, static_cast<BlockwiseGemmArguments const *>(arguments_ptr));
if (status != Status::kSuccess) {
return 0;
}
uint64_t size = Operator::get_workspace_size(args);
return size;
}
/// Initializes the workspace
Status initialize(
void const *configuration_ptr,
void *host_workspace,
void *device_workspace,
cudaStream_t stream = nullptr) const override {
Operator *op = new (host_workspace) Operator;
return Status::kSuccess;
}
Status initialize_with_profiler_workspace(
void const *configuration,
void *host_workspace,
void *device_workspace,
uint8_t **profiler_workspaces,
int problem_count_from_profiler,
cudaStream_t stream = nullptr) {
return Status::kSuccess;
}
/// Runs the kernel
Status run(
void const *arguments_ptr,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const override {
OperatorArguments args;
Status status = update_arguments_(args, static_cast<BlockwiseGemmArguments const *>(arguments_ptr));
if (status != Status::kSuccess) {
return status;
}
Operator *op = static_cast<Operator *>(host_workspace);
// We need to call initialize() since we have to rebuild TMA desc for every new set of args
status = op->run(args, device_workspace, stream, nullptr, static_cast<BlockwiseGemmArguments const *>(arguments_ptr)->use_pdl);
return status;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::library
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -72,19 +72,6 @@ public:
this->description_.gemm = GemmOperation3xBase<Operator_>::description_;
this->description_.tile_description = this->description_.gemm.tile_description;
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster_dims(
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
cluster_dims,
threads_per_block,
kernel_ptr);
}
};
public:
@ -102,7 +89,6 @@ public:
protected:
library::GroupedGemmDescription description_;
int max_active_clusters;
Status initialize_strides(GemmGroupedConfiguration const& config) const {
auto const num_groups = config.problem_count;
@ -182,7 +168,7 @@ protected:
operator_args.hw_info.sm_count = arguments.sm_count;
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
operator_args.hw_info.max_active_clusters = max_active_clusters;
operator_args.hw_info.max_active_clusters = arguments.max_active_clusters;
}
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
operator_args.hw_info.cluster_shape =
@ -343,6 +329,47 @@ public:
status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl);
return status;
}
// Set arguments that should only be set once before verifying or profiling the kernel.
// This should encompass any expensive operations that don't vary from run to run
// (e.g., max_active_clusters).
Status initialize_with_arguments(void* arguments_ptr) const override {
if constexpr (Operator::ArchTag::kMinComputeCapability < 90) {
return Status::kSuccess;
}
GemmGroupedArguments* args = static_cast<GemmGroupedArguments*>(arguments_ptr);
dim3 cluster_dims;
if constexpr (cute::is_static_v<typename Operator::GemmKernel::ClusterShape>) {
cluster_dims = dim3(
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
cute::size<2>(typename Operator::GemmKernel::ClusterShape{})
);
}
else {
cluster_dims = dim3(
args->cluster_shape.m(),
args->cluster_shape.n(),
args->cluster_shape.k()
);
}
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
cluster_dims,
threads_per_block,
kernel_ptr);
if (args->max_active_clusters == 0) {
return Status::kErrorInternal;
}
return Status::kSuccess;
}
};
template <typename Operator_>
@ -375,6 +402,7 @@ public:
: GroupedGemmOperation3xBase<Operator_>(name) {
BlockScaleDescription block_scaled_desc{};
block_scaled_desc.kind = OperationKind::kBlockScaledGemm;
block_scaled_desc.SFA.element = NumericTypeMap<ElementSFA>::kId;
block_scaled_desc.SFA.layout = LayoutTypeID::kRowMajor;
block_scaled_desc.SFA.alignment = 128;
@ -387,7 +415,9 @@ public:
block_scaled_desc.SFB.log_extent_range = 32;
block_scaled_desc.SFB.log_stride_range = 32;
block_scaled_desc.SFVecSize = SFVecSize;
block_scaled_desc.SFMVecSize = 1;
block_scaled_desc.SFNVecSize = 1;
block_scaled_desc.SFKVecSize = SFVecSize;
block_scaled_desc.SFD = make_TensorDescription<ElementSFD, LayoutSFD>(128);
block_scaled_desc.EpilogueSFVecSize = SFD_VectorSize;
@ -555,4 +585,206 @@ public:
}
};
template <typename Operator_>
class GroupedBlockwiseGemmUniversal3xOperation : public GroupedGemmOperation3xBase<Operator_> {
public:
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
using ElementD = typename Operator::ElementD;
using LayoutD = typename Operator::LayoutD;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using CollectiveMainloop = typename Operator::CollectiveMainloop;
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
using ElementSFA = typename Operator::ElementAccumulator;
using ElementSFB = typename Operator::ElementAccumulator;
using TiledMma = typename Operator::CollectiveMainloop::TiledMma;
GroupedBlockwiseGemmUniversal3xOperation(char const* name = "unknown_gemm")
: GroupedGemmOperation3xBase<Operator_>(name) {
BlockScaleDescription blockwise_desc{};
blockwise_desc.kind = OperationKind::kBlockwiseGemm;
blockwise_desc.SFA.element = NumericTypeMap<ElementSFA>::kId;
blockwise_desc.SFA.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFA{}.stride()) == 1 ?
LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor;
blockwise_desc.SFA.alignment = CollectiveMainloop::AlignmentSFA;
blockwise_desc.SFA.log_extent_range = 32;
blockwise_desc.SFA.log_stride_range = 32;
blockwise_desc.SFB.element = NumericTypeMap<ElementSFB>::kId;
blockwise_desc.SFB.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFB{}.stride()) == 1 ?
LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor;
blockwise_desc.SFB.alignment = CollectiveMainloop::AlignmentSFA;
blockwise_desc.SFB.log_extent_range = 32;
blockwise_desc.SFB.log_stride_range = 32;
blockwise_desc.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM;
blockwise_desc.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN;
blockwise_desc.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK;
blockwise_desc.EpilogueSFVecSize = 0;
this->description_.block_scales = blockwise_desc;
}
~GroupedBlockwiseGemmUniversal3xOperation() override = default;
mutable CudaBuffer layout_SFA_device;
mutable CudaBuffer layout_SFB_device;
protected:
template <class FusionArgs, class = void> struct UpdateFusionArgs {
static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) {
// If a custom EVT is instantiated then it is the users's responsibility
// to ensure alpha and beta are updated appropriately
return Status::kSuccess;
}
};
template <class FusionArgs>
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
static Status
update_(FusionArgs& fusion_args, GroupedGemmBlockwiseArguments const& arguments) {
return GroupedGemmOperation3xBase<Operator>::update_fusion_args(fusion_args, arguments);
}
};
public:
/// Returns success if the operation can proceed
Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr)
const override {
GroupedGemmBlockwiseArguments const* arguments =
static_cast<GroupedGemmBlockwiseArguments const*>(arguments_ptr);
OperatorArguments args;
auto status = update_arguments_(args, arguments);
if (status != Status::kSuccess) {
return status;
}
status = Operator::can_implement(args);
return status;
}
Status update_arguments_(
OperatorArguments& operator_args,
GroupedGemmBlockwiseArguments const* arguments) const {
Status status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
operator_args.epilogue.thread,
*arguments);
if (status != Status::kSuccess) {
return status;
}
operator_args.mainloop.ptr_SFA =
static_cast<const typename Operator::GemmKernel::ElementAccumulator**>(arguments->SFA);
operator_args.mainloop.ptr_SFB =
static_cast<const typename Operator::GemmKernel::ElementAccumulator**>(arguments->SFB);
operator_args.mainloop.layout_SFA =
static_cast<typename CollectiveMainloop::InternalLayoutSFA*>(this->layout_SFA_device.data());
operator_args.mainloop.layout_SFB =
static_cast<typename CollectiveMainloop::InternalLayoutSFB*>(this->layout_SFB_device.data());
return this->update_arguments_base(operator_args, *arguments);
}
uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr)
const override {
OperatorArguments args;
auto status =
update_arguments_(args, static_cast<GroupedGemmBlockwiseArguments const*>(arguments_ptr));
if (status != Status::kSuccess) {
return 0;
}
uint64_t size = Operator::get_workspace_size(args);
return size;
}
/// Initializes the workspace
/// **** CAUTION ****
/// Must be called when lda, ldb, ldc, or ldd change.
/// The CUTLASS library stores the operations in a type-
/// erased manifest. Therefore, only this class knows
/// the type of strideA, strideB, strideC, and strideD.
/// Since grouped GEMM needs to allocate storage for
/// the strides on device, the concrete type of the stride
/// must be known in order to copy in the correct memory
/// layout on device.
Status initialize(
void const* configuration_ptr,
void* host_workspace,
void* device_workspace,
cudaStream_t stream = nullptr) const override {
auto const& config = *static_cast<GemmGroupedConfiguration const*>(configuration_ptr);
auto status = this->initialize_strides(config);
if (status != Status::kSuccess) {
return status;
}
auto num_groups = config.problem_count;
this->layout_SFA_device =
CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups);
this->layout_SFB_device =
CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups);
auto layout_SFA_host = std::vector<typename CollectiveMainloop::InternalLayoutSFA>(num_groups);
auto layout_SFB_host = std::vector<typename CollectiveMainloop::InternalLayoutSFB>(num_groups);
for (int group_idx = 0; group_idx < num_groups; group_idx++) {
auto const& shape = config.problem_sizes_3x_host[group_idx];
auto M = get<0>(shape);
auto N = get<1>(shape);
auto K = get<2>(shape);
auto layout_SFA = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
layout_SFA_host[group_idx] = layout_SFA;
layout_SFB_host[group_idx] = layout_SFB;
}
CUDA_CHECK(cudaMemcpy(
this->layout_SFA_device.data(),
layout_SFA_host.data(),
sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups,
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(
this->layout_SFB_device.data(),
layout_SFB_host.data(),
sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups,
cudaMemcpyHostToDevice));
Operator* op = new (host_workspace) Operator;
return status;
}
/// **** CAUTION ****
/// initialize() must be called if lda, ldb, ldc, or ldd change.
Status run(
void const* arguments_ptr,
void* host_workspace,
void* device_workspace = nullptr,
cudaStream_t stream = nullptr) const override {
OperatorArguments operator_args;
auto const& args = *static_cast<GroupedGemmBlockwiseArguments const*>(arguments_ptr);
Status status = update_arguments_(operator_args, &args);
if (status != Status::kSuccess) {
return status;
}
Operator* op = static_cast<Operator*>(host_workspace);
status = op->run(operator_args, device_workspace, stream, nullptr);
return status;
}
};
} // namespace cutlass::library

View File

@ -86,6 +86,41 @@ void OperationTable::append(Manifest const &manifest) {
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
}
if (desc.kind == OperationKind::kBlockwiseGemm) {
BlockwiseGemmDescription const &gemm_desc = static_cast<BlockwiseGemmDescription const &>(desc);
BlockwiseGemmFunctionalKey functional_key(
gemm_desc.provider,
gemm_desc.gemm_kind,
gemm_desc.kind,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
gemm_desc.A.element,
gemm_desc.A.layout,
gemm_desc.SFA.element,
gemm_desc.B.element,
gemm_desc.B.layout,
gemm_desc.SFB.element,
gemm_desc.C.element,
gemm_desc.C.layout,
gemm_desc.D.element,
gemm_desc.D.layout,
gemm_desc.SFMVecSize,
gemm_desc.SFNVecSize,
gemm_desc.SFKVecSize
);
Operation const *op = operation.get();
int cc = gemm_desc.tile_description.minimum_compute_capability;
int alignment = std::max(std::max(
gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment);
GemmPreferenceKey preference_key(cc, alignment);
blockwise_gemm_operations[functional_key][preference_key].push_back(op);
}
// insert all gemm operation into operation table
if (desc.kind == OperationKind::kGemm) {
@ -157,29 +192,57 @@ void OperationTable::append(Manifest const &manifest) {
}
else {
const BlockScaleDescription &block_scale_desc = grouped_gemm_desc.block_scales.value();
BlockScaledGemmFunctionalKey functional_key(
gemm_desc.provider,
gemm_desc.gemm_kind,
gemm_desc.kind,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
gemm_desc.A.element,
gemm_desc.A.layout,
block_scale_desc.SFA.element,
gemm_desc.B.element,
gemm_desc.B.layout,
block_scale_desc.SFB.element,
gemm_desc.C.element,
gemm_desc.C.layout,
gemm_desc.D.element,
gemm_desc.D.layout,
block_scale_desc.SFD.element,
block_scale_desc.SFD.layout,
block_scale_desc.SFVecSize,
block_scale_desc.EpilogueSFVecSize
);
if (block_scale_desc.kind == OperationKind::kBlockScaledGemm) {
BlockScaledGemmFunctionalKey functional_key(
gemm_desc.provider,
gemm_desc.gemm_kind,
gemm_desc.kind,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
gemm_desc.A.element,
gemm_desc.A.layout,
block_scale_desc.SFA.element,
gemm_desc.B.element,
gemm_desc.B.layout,
block_scale_desc.SFB.element,
gemm_desc.C.element,
gemm_desc.C.layout,
gemm_desc.D.element,
gemm_desc.D.layout,
block_scale_desc.SFD.element,
block_scale_desc.SFD.layout,
block_scale_desc.SFKVecSize,
block_scale_desc.EpilogueSFVecSize
);
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
}
else {
assert(block_scale_desc.kind == OperationKind::kBlockwiseGemm);
BlockwiseGemmFunctionalKey functional_key(
gemm_desc.provider,
gemm_desc.gemm_kind,
gemm_desc.kind,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
gemm_desc.A.element,
gemm_desc.A.layout,
block_scale_desc.SFA.element,
gemm_desc.B.element,
gemm_desc.B.layout,
block_scale_desc.SFB.element,
gemm_desc.C.element,
gemm_desc.C.layout,
gemm_desc.D.element,
gemm_desc.D.layout,
block_scale_desc.SFMVecSize,
block_scale_desc.SFNVecSize,
block_scale_desc.SFKVecSize
);
blockwise_gemm_operations[functional_key][preference_key].push_back(op);
}
}
}

View File

@ -0,0 +1,58 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "blockwise_gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
void initialize_blockwise_gemm_reference_operations_bf16out(Manifest &manifest) {
initialize_blockwise_gemm_reference_operations_given_C_and_D<void, bfloat16_t>(manifest);
initialize_blockwise_gemm_reference_operations_given_C_and_D<bfloat16_t, bfloat16_t>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,58 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "blockwise_gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
void initialize_blockwise_gemm_reference_operations_fp16out(Manifest &manifest) {
initialize_blockwise_gemm_reference_operations_given_C_and_D<void, half_t>(manifest);
initialize_blockwise_gemm_reference_operations_given_C_and_D<half_t, half_t>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,58 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "blockwise_gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
void initialize_blockwise_gemm_reference_operations_fp32out(Manifest &manifest) {
initialize_blockwise_gemm_reference_operations_given_C_and_D<void, float>(manifest);
initialize_blockwise_gemm_reference_operations_given_C_and_D<float, float>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,664 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library
*/
#pragma once
#include <iostream>
#include <sstream>
#include <cstring>
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/util.h"
#include "cutlass/util/packed_stride.hpp"
#include "library_internal.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/detail/blockwise_scale_layout.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
Provider Provider_,
typename ElementA_,
typename LayoutA_,
typename LayoutSFA_,
typename ElementSFA_,
typename ElementB_,
typename LayoutB_,
typename LayoutSFB_,
typename ElementSFB_,
typename ElementC_,
typename LayoutC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ElementD_ = ElementC_,
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
class BlockwiseGemmReferenceOperation : public Operation {
public:
static Provider const kProvider = Provider_;
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using ElementSFA = ElementSFA_;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using ElementSFB = ElementSFB_;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using ElementD = ElementD_;
using ElementCompute = ElementCompute_;
using ElementAccumulator = ElementAccumulator_;
using ConvertOp = ConvertOp_;
using InnerProductOp = InnerProductOp_;
protected:
/// Storage for the name string
std::string name_;
///
BlockwiseGemmDescription description_;
public:
/// Constructor
BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_)
: SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) {
// Basic information
description_.provider = kProvider;
description_.kind = OperationKind::kBlockwiseGemm;
description_.gemm_kind = GemmKind::kUniversal;
// Tensor description
description_.A = make_TensorDescription<ElementA, LayoutA>();
description_.SFA = make_TensorDescription<ElementSFA, LayoutSFA_>();
description_.B = make_TensorDescription<ElementB, LayoutB>();
description_.SFB = make_TensorDescription<ElementSFB, LayoutSFB_>();
description_.C = make_TensorDescription<ElementC, LayoutC>();
description_.D = make_TensorDescription<ElementD, LayoutC>();
// Epilogue compute and accumulator type description
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
// Compute capability for gemm reference
description_.tile_description.minimum_compute_capability =
(kProvider == Provider::kReferenceDevice ? 50 : 0);
description_.tile_description.maximum_compute_capability = 1024;
description_.SFMVecSize = SFMVecSize;
description_.SFNVecSize = SFNVecSize;
description_.SFKVecSize = SFKVecSize;
// Procedural name
std::stringstream ss;
ss << "gemm"
<< "_reference_" << to_string(description_.provider)
<< "_" << to_string(description_.A.element) << to_string(description_.A.layout)
<< "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout)
<< "_" << to_string(description_.B.element) << to_string(description_.B.layout)
<< "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout)
<< "_" << to_string(description_.C.element) << to_string(description_.C.layout)
<< "_" << to_string(description_.tile_description.math_instruction.element_accumulator);
name_ = ss.str();
description_.name = name_.c_str();
// Epilogue compute and accumulator type description
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
}
/// Returns the description of the GEMM operation
virtual OperationDescription const & description() const {
return description_;
}
virtual Status can_implement(
void const *configuration,
void const *arguments) const {
return Status::kSuccess;
}
virtual uint64_t get_host_workspace_size(
void const *configuration) const {
return sizeof(GemmUniversalConfiguration);
}
virtual uint64_t get_device_workspace_size(
void const *configuration,
void const *arguments = nullptr) const {
return 0;
}
virtual Status initialize(
void const *configuration,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
return Status::kSuccess;
}
virtual Status run(
void const *arguments,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
using namespace cute;
BlockwiseGemmArguments const &args = *static_cast<BlockwiseGemmArguments const *>(arguments);
// Construct cute::Tensor A/B/C
int M = args.problem_size.m();
int N = args.problem_size.n();
int K = args.problem_size.k();
int L = args.batch_count;
auto problem_shape_MNKL = cute::make_shape(M, N, K, L);
auto alpha = *(static_cast<ElementCompute const*>(args.alpha));
auto beta = *(static_cast<ElementCompute const*>(args.beta));
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutC>;
auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>;
auto A = cute::make_tensor(static_cast<ElementA const*>(args.A),
cute::make_layout(cute::make_shape(M, K, L), stride_a));
auto SfA = make_tensor(static_cast<ElementSFA const*>(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize)));
auto B = cute::make_tensor(static_cast<ElementB const*>(args.B),
cute::make_layout(cute::make_shape(N, K, L), stride_b));
auto SfB = make_tensor(static_cast<ElementSFB const*>(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize)));
auto C = [&]() {
if constexpr (not is_same_v<ElementC, void>) {
return cute::make_tensor(static_cast<ElementC const*>(args.C),
cute::make_layout(cute::make_shape(M, N, L), stride_c));
}
else {
return cute::make_tensor(static_cast<ElementD const*>(nullptr),
cute::make_layout(cute::make_shape(M, N, L), stride_c));
}
}();
auto D = cute::make_tensor(static_cast<ElementD *>(args.D),
cute::make_layout(cute::make_shape(M, N, L), stride_d));
cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
decltype(A), decltype(SfA),
decltype(B), decltype(SfB)>
mainloop_params{A, SfA, B, SfB};
// W/O SF generation
cutlass::reference::host::GettEpilogueParams<
ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute,
decltype(C), decltype(D)>
epilogue_params{alpha, beta, C, D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
return Status::kSuccess;
}
private:
int SFMVecSize;
int SFNVecSize;
int SFKVecSize;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename ElementA_,
typename ElementSFA_,
typename ElementB_,
typename ElementSFB_,
typename ElementC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ElementD_ = ElementC_,
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) {
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
}
template<class ElementC,
class ElementD>
void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) {
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -64,6 +64,10 @@ void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest);
void initialize_blockwise_gemm_reference_operations_fp32out(Manifest &manifest);
void initialize_blockwise_gemm_reference_operations_fp16out(Manifest &manifest);
void initialize_blockwise_gemm_reference_operations_bf16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
@ -113,6 +117,9 @@ void initialize_reference_operations(Manifest &manifest) {
initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest);
initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest);
initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest);
initialize_blockwise_gemm_reference_operations_fp32out(manifest);
initialize_blockwise_gemm_reference_operations_fp16out(manifest);
initialize_blockwise_gemm_reference_operations_bf16out(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -334,6 +334,7 @@ static struct {
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
{"gemm", "Gemm", OperationKind::kGemm},
{"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm},
{"blockwise_gemm", "blockwiseGemm", OperationKind::kBlockwiseGemm},
{"rank_k", "RankK", OperationKind::kRankK},
{"rank_2k", "Rank2K", OperationKind::kRank2K},
{"trmm", "Trmm", OperationKind::kTrmm},

View File

@ -48,6 +48,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
src/gemm_operation_profiler.cu
src/grouped_gemm_operation_profiler.cu
src/block_scaled_gemm_operation_profiler.cu
src/blockwise_gemm_operation_profiler.cu
src/rank_k_operation_profiler.cu
src/rank_2k_operation_profiler.cu
src/trmm_operation_profiler.cu

View File

@ -0,0 +1,290 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Blockscale Gemm Profiler
*/
#pragma once
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <unordered_map>
// CUTLASS Library includes
#include "cutlass/library/library.h"
#include "cutlass/library/util.h"
#include "cutlass/library/manifest.h"
// Profiler includes
#include "options.h"
#include "device_context.h"
#include "operation_profiler.h"
#include "performance_result.h"
#include "problem_space.h"
#include "reduction_operation_profiler.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace profiler {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Abstract base class for each math function
class BlockwiseGemmOperationProfiler : public OperationProfiler {
public:
/// Problem structure obtained from problem space
struct GemmProblem {
cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm};
int64_t m{16};
int64_t n{16};
int64_t k{16};
int64_t sf_vec_m{0};
int64_t sf_vec_n{0};
int64_t sf_vec_k{0};
int cluster_m{1};
int cluster_n{1};
int cluster_k{1};
int cluster_m_fallback{1};
int cluster_n_fallback{1};
int cluster_k_fallback{1};
int64_t lda{0};
int64_t ldb{0};
int64_t ldc{0};
std::vector<uint8_t> alpha;
std::vector<uint8_t> beta;
cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone};
int split_k_slices{1};
int batch_count{1};
cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
int swizzle_size{1};
cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
// gemm with parallel interleaved reduction
// gemm epilogue (alpha, beta) = (1.0, 0.0)
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
std::vector<uint8_t> alpha_one;
std::vector<uint8_t> beta_zero;
bool use_pdl{false};
//
// Methods
//
/// Parses the problem
Status parse(
library::BlockwiseGemmDescription const &operation_desc,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Total number of bytes loaded
int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const;
/// Total number of flops computed
int64_t flops(library::BlockwiseGemmDescription const &operation_desc) const;
/// Initializes a performance result
void initialize_result(
PerformanceResult &result,
library::BlockwiseGemmDescription const &operation_desc,
ProblemSpace const &problem_space);
};
/// Workspace used
struct GemmWorkspace {
DeviceAllocation *A{nullptr};
DeviceAllocation *SFA{nullptr};
DeviceAllocation *B{nullptr};
DeviceAllocation *SFB{nullptr};
DeviceAllocation *C{nullptr};
DeviceAllocation *Computed{nullptr};
DeviceAllocation *Reference{nullptr};
/// Number of copies of the problem workspace which are visited sequentially during
/// profiling to avoid camping in the last level cache.
int problem_count{1};
library::GemmUniversalConfiguration configuration;
library::BlockwiseGemmArguments arguments;
/// Buffer used for the operation's host workspace
std::vector<uint8_t> host_workspace;
/// Buffer used for the operations' device workspace
DeviceAllocation device_workspace;
/// Library configuration and arguments for reduction operator
library::ReductionConfiguration reduction_configuration;
library::ReductionArguments reduction_arguments;
/// Buffer used for the cutlass reduction operations' host workspace
std::vector<uint8_t> reduction_host_workspace;
};
protected:
//
// Data members
//
/// GEMM problem obtained from problem space
GemmProblem problem_;
/// Device memory allocations
GemmWorkspace gemm_workspace_;
/// CUTLASS parallel reduction operation to follow this* gemm operation
library::Operation const *reduction_op_;
public:
//
// Methods
//
/// Ctor
BlockwiseGemmOperationProfiler(Options const &options);
/// Destructor
virtual ~BlockwiseGemmOperationProfiler();
GemmProblem const& problem() const { return problem_; }
/// Prints usage statement for the math function
virtual void print_usage(std::ostream &out) const;
/// Prints examples
virtual void print_examples(std::ostream &out) const;
/// Extracts the problem dimensions
virtual Status initialize_configuration(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Initializes workspace
virtual Status initialize_workspace(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Verifies CUTLASS against references
virtual bool verify_cutlass(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Measures performance results
virtual bool profile(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
protected:
/// Initializes the performance result
void initialize_result_(
PerformanceResult &result,
Options const &options,
library::BlockwiseGemmDescription const &operation_desc,
ProblemSpace const &problem_space);
/// Verifies CUTLASS against references
bool verify_with_cublas_(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Verifies CUTLASS against host and device references
bool verify_with_reference_(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem,
cutlass::library::NumericTypeID element_A,
cutlass::library::NumericTypeID element_B);
/// Method to profile a CUTLASS Operation
Status profile_cutlass_(
PerformanceResult &result,
Options const &options,
library::Operation const *operation,
void *arguments,
void *host_workspace,
void *device_workspace);
/// Initialize reduction problem dimensions and library::Operation
bool initialize_reduction_configuration_(
library::Operation const *operation,
ProblemSpace::Problem const &problem);
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -198,6 +198,11 @@ private:
arguments.SFD = block_scaled_ws.SFD_ptr_array_device[0]->data();
arguments.norm_constant = block_scaled_ws.norm_constant->data();
}
else if (is_blockwise) {
auto& block_scaled_ws = gemm_workspace_.block_scales.value();
arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data();
arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data();
}
}
protected:
@ -208,6 +213,7 @@ protected:
GroupedGemmWorkspace gemm_workspace_;
bool is_block_scaled{false};
bool is_blockwise{false};
public:
GroupedGemmOperationProfiler(Options const& options);

View File

@ -437,9 +437,11 @@ void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "k", problem_space, k);
set_argument(result, "cluster_m", problem_space, cluster_m);
set_argument(result, "cluster_n", problem_space, cluster_n);
set_argument(result, "cluster_k", problem_space, cluster_k);
auto cluster_shape = operation_desc.tile_description.cluster_shape;
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);

File diff suppressed because it is too large Load Diff

View File

@ -37,6 +37,7 @@
// Profiler includes
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
#include "cutlass/profiler/blockwise_gemm_operation_profiler.h"
#include "cutlass/profiler/conv2d_operation_profiler.h"
#include "cutlass/profiler/conv3d_operation_profiler.h"
#include "cutlass/profiler/cutlass_profiler.h"
@ -64,6 +65,8 @@ CutlassProfiler::CutlassProfiler(
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
operation_profilers_.emplace_back(new BlockwiseGemmOperationProfiler(options));
operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options));
operation_profilers_.emplace_back(new Conv2dOperationProfiler(options));

View File

@ -440,10 +440,11 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "n", problem_space, n);
set_argument(result, "k", problem_space, k);
set_argument(result, "cluster_m", problem_space, cluster_m);
set_argument(result, "cluster_n", problem_space, cluster_n);
set_argument(result, "cluster_k", problem_space, cluster_k);
auto cluster_shape = operation_desc.tile_description.cluster_shape;
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);

View File

@ -40,6 +40,7 @@
#include <stdexcept>
#include <string>
#include <vector>
#include <regex>
#include <cuda_runtime_api.h>
@ -459,9 +460,11 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
set_argument(result, "problem-sizes", problem_space, ss.str());
}
set_argument(result, "cluster_m", problem_space, cluster_m);
set_argument(result, "cluster_n", problem_space, cluster_n);
set_argument(result, "cluster_k", problem_space, cluster_k);
auto cluster_shape = operation_desc.gemm.tile_description.cluster_shape;
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
@ -497,10 +500,22 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
// We distinguish between block scaled and non-block scaled operations by looking at the kernel
// name, which tells us what reference kernel to use, which arguments to pass to the operation
// etc. This avoids creating yet another OperationProfiler with a lot of boilerplate in it.
std::string sf_tuple = "\\d+x\\d+";
std::string datatypes_regex = "\\w?f\\d+|e\\dm\\d"; // bf16 | f16 | f32 | e4m3 | ...
std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" +
datatypes_regex + ")_" + sf_tuple + "(" +
datatypes_regex + ")x(" + datatypes_regex + ")";
if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) {
is_block_scaled = true;
gemm_workspace_.block_scales = BlockScalingWorkspace{};
}
else if (std::regex_search(operation_desc.gemm.name, std::regex(blockwise_regex_string))) {
is_blockwise = true;
gemm_workspace_.block_scales = BlockScalingWorkspace{};
}
else {
is_block_scaled = false;
gemm_workspace_.block_scales = std::nullopt;
@ -605,6 +620,12 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
block_scaling_ws.SFD_ptr_array_host.resize(num_groups);
block_scaling_ws.SFD_reference_ptr_array_host.resize(num_groups);
}
else if (is_blockwise) {
auto& block_scaling_ws = gemm_workspace_.block_scales.value();
block_scaling_ws.SFA_ptr_array_host.resize(num_groups);
block_scaling_ws.SFB_ptr_array_host.resize(num_groups);
block_scaling_ws.SFC_ptr_array_host.resize(num_groups);
}
static_assert(sizeof(void*) == 8); // allocating blocks for pointers, so verify pointer size
// ldx
gemm_workspace_.lda_array_device =
@ -698,7 +719,7 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
int sfa_m = round_up(int(problem_.m(group_idx)), 128);
int sfb_n = round_up(int(problem_.n(group_idx)), 128);
int sfa_sfb_k =
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFVecSize), 4);
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize), 4);
int sfd_m =
block_scale_desc.SFD.layout == cutlass::library::LayoutTypeID::kRowMajor
@ -760,6 +781,37 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
block_scale_ws.SFD_ptr_array_host[group_idx]->fill_device(0);
}
}
else if (is_blockwise) {
auto const block_scale_desc = operation_desc.block_scales.value();
auto& block_scale_ws = gemm_workspace_.block_scales.value();
int sfa_m = ceil_div(int(problem_.m(group_idx)), block_scale_desc.SFMVecSize);
int sfb_n = ceil_div(int(problem_.n(group_idx)), block_scale_desc.SFNVecSize);
int sfa_sfb_k = ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize);
block_scale_ws.SFA_ptr_array_host[group_idx] =
device_context.allocate_and_initialize_tensor(
options,
"SFA_" + std::to_string(group_idx),
block_scale_desc.SFA.element,
block_scale_desc.SFA.layout,
{sfa_m, sfa_sfb_k},
{sfa_m},
gemm_workspace_.problem_count,
seed_shift++,
0);
block_scale_ws.SFB_ptr_array_host[group_idx] =
device_context.allocate_and_initialize_tensor(
options,
"SFB_" + std::to_string(group_idx),
block_scale_desc.SFB.element,
block_scale_desc.SFB.layout,
{sfa_sfb_k, sfb_n},
{sfb_n},
gemm_workspace_.problem_count,
seed_shift++,
0);
}
}
// takes the allocated tensors and initializes an array of pointers per problem in the workspace
@ -825,6 +877,18 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
0 // device_index
);
}
else if (is_blockwise) {
auto& block_scale_ws = gemm_workspace_.block_scales.value();
create_dev_ptr_array_all_workspace(
block_scale_ws.SFA_ptr_array_device,
block_scale_ws.SFA_ptr_array_host,
"SFA");
create_dev_ptr_array_all_workspace(
block_scale_ws.SFB_ptr_array_device,
block_scale_ws.SFB_ptr_array_host,
"SFB");
}
init_arguments(options);
}
@ -896,6 +960,11 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
init_arguments(options);
library::Operation const* underlying_operation = operation;
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
if (results_.back().status != Status::kSuccess) {
return false;
}
results_.back().status = underlying_operation->run(
&gemm_workspace_.arguments,
gemm_workspace_.host_workspace.data(),
@ -998,7 +1067,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
}
// we only have a block scaled reference kernel implemented on the host
if (is_block_scaled && provider != library::Provider::kReferenceHost) {
if ((is_block_scaled || is_blockwise) && provider != library::Provider::kReferenceHost) {
continue;
}
@ -1064,12 +1133,22 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
ptr_norm_constant = host_data_norm_constant.data();
ws.norm_constant->copy_to_host(ptr_norm_constant);
}
else if (is_blockwise) {
auto const& ws = gemm_workspace_.block_scales.value();
host_data_SFA.resize(ws.SFA_ptr_array_host[group_idx]->bytes());
ptr_SFA = host_data_SFA.data();
ws.SFA_ptr_array_host[group_idx]->copy_to_host(ptr_SFA);
host_data_SFB.resize(ws.SFB_ptr_array_host[group_idx]->bytes());
ptr_SFB = host_data_SFB.data();
ws.SFB_ptr_array_host[group_idx]->copy_to_host(ptr_SFB);
}
}
const auto &desc = static_cast<library::GroupedGemmDescription const &>(operation->description());
const auto& gemm_desc = desc.gemm;
if (!is_block_scaled) {
if (!is_block_scaled and !is_blockwise) {
library::Handle handle;
handle.set_provider(provider);
@ -1112,7 +1191,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride());
}
else {
else if (is_block_scaled) {
auto const& block_scale_desc = desc.block_scales.value();
auto& block_scale_ws = gemm_workspace_.block_scales.value();
@ -1134,7 +1213,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
gemm_desc.D.layout,
block_scale_desc.SFD.element,
block_scale_desc.SFD.layout,
block_scale_desc.SFVecSize,
block_scale_desc.SFKVecSize,
block_scale_desc.EpilogueSFVecSize);
auto operators_it =
@ -1208,6 +1287,100 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
block_scale_ws.SFD_reference_ptr_array_host[group_idx]->copy_from_host(ptr_SFD);
}
else {
// Blockwise
auto const& block_scale_desc = desc.block_scales.value();
auto& block_scale_ws = gemm_workspace_.block_scales.value();
library::BlockwiseGemmFunctionalKey blockwiseGemm_key(
library::Provider::kReferenceHost,
library::GemmKind::kUniversal,
library::OperationKind::kBlockwiseGemm,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
element_A,
gemm_desc.A.layout,
block_scale_desc.SFA.element,
element_B,
gemm_desc.B.layout,
block_scale_desc.SFB.element,
gemm_desc.C.element,
gemm_desc.C.layout,
gemm_desc.D.element,
gemm_desc.D.layout,
block_scale_desc.SFMVecSize,
block_scale_desc.SFNVecSize,
block_scale_desc.SFKVecSize
);
auto operators_it = library::Singleton::get().operation_table.blockwise_gemm_operations.find(blockwiseGemm_key);
if (
operators_it ==
library::Singleton::get().operation_table.blockwise_gemm_operations.end()) {
disposition = Disposition::kNotSupported;
break;
}
if (operators_it->second.empty()) {
disposition = Disposition::kNotSupported;
break;
}
auto cc_it = operators_it->second.begin();
if (cc_it == operators_it->second.end()) {
disposition = Disposition::kNotSupported;
break;
}
// host reference has only one instances in BlockScaledOperationVectorMap
library::Operation const* reference_op = cc_it->second[0];
library::BlockwiseGemmArguments arguments {
{int(problem_.m(group_idx)), int(problem_.n(group_idx)), int(problem_.k(group_idx))},
{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)},
{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)},
1, // batch_count
ptr_A,
ptr_B,
ptr_SFA,
ptr_SFB,
ptr_C,
ptr_D,
problem_.alpha.data(),
problem_.beta.data(),
library::ScalarPointerMode::kHost,
problem_.lda[group_idx],
problem_.ldb[group_idx],
problem_.ldc[group_idx],
problem_.ldc[group_idx],
gemm_workspace_.A_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.B_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride(),
};
library::GemmUniversalConfiguration configuration{
library::GemmUniversalMode::kGemm,
problem_.problem_sizes[group_idx],
{problem_.cluster_m, problem_.cluster_n, problem_.cluster_k},
{problem_.cluster_m_fallback, problem_.cluster_n_fallback, problem_.cluster_k_fallback},
1,
problem_.lda[group_idx],
problem_.ldb[group_idx],
problem_.ldc[group_idx],
problem_.ldc[group_idx],
1,
};
uint64_t host_workspace_size_needed = reference_op->get_host_workspace_size(&gemm_workspace_.configuration);
std::vector<char> host_workspace(host_workspace_size_needed);
status = reference_op->initialize(&configuration, host_workspace.data());
if (status != Status::kSuccess) {
break;
}
status = reference_op->run(&arguments, host_workspace.data());
}
if (status != Status::kSuccess) {
break;
}
@ -1292,6 +1465,10 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
void* device_workspace) {
library::Operation const* underlying_operation = operation;
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
if (results_.back().status != Status::kSuccess) {
return results_.back().status;
}
auto func = [&](cudaStream_t stream, int iteration) {
// Iterate over copies of the problem in memory

View File

@ -301,6 +301,9 @@ std::ostream& operator<<(std::ostream& out, library::OperationKind op_kind) {
else if (op_kind == library::OperationKind::kBlockScaledGemm) {
out << "kBlockScaledGemm";
}
else if (op_kind == library::OperationKind::kBlockwiseGemm) {
out << "kBlockwiseGemm";
}
else if (op_kind == library::OperationKind::kRankK) {
out << "kRankK";
}