CUTLASS 3.8 Release (#2059)

* CUTLASS 3.8 Release

* update

* Update README.md

* Revert "Update README.md"

This reverts commit b353e36fe8.

* update

* update

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
mihir-awatramani
2025-01-24 23:44:06 -08:00
committed by GitHub
parent 9eb01fa0b0
commit 389e493055
290 changed files with 91223 additions and 292 deletions

View File

@ -221,6 +221,19 @@ cutlass_add_cutlass_library(
# files split for parallel compilation
src/reference/gemm_int4.cu
src/reference/block_scaled_gemm_fp4a_vs16.cu
src/reference/block_scaled_gemm_fp4a_vs32.cu
src/reference/block_scaled_gemm_mixed8bitsa.cu
src/reference/gemm_f4_f4_f32.cu
src/reference/gemm_f4_f6_f32.cu
src/reference/gemm_f4_f8_f32.cu
src/reference/gemm_f6_f4_f32.cu
src/reference/gemm_f6_f6_f32.cu
src/reference/gemm_f6_f8_f32.cu
src/reference/gemm_f8_f4_f32.cu
src/reference/gemm_f8_f6_f32.cu
src/reference/gemm_s8_s8_s32.cu
src/reference/gemm_u8_u8_s32.cu
src/reference/gemm_int8_interleaved_32.cu

View File

@ -119,6 +119,18 @@ template <> struct ArchMap<arch::Sm90, arch::OpClassSparseTensorOp> {
static int const kMax = 90;
};
// Partial specialization: any operator class on SM100 — accepts kernels from
// compute capability 100 up to an open-ended maximum (1024).
template <typename OperatorClass> struct ArchMap<arch::Sm100, OperatorClass> {
static int const kMin = 100;
static int const kMax = 1024;
};
// Full specialization: tensor-op kernels on SM100 are restricted to exactly CC 100.
template <> struct ArchMap<arch::Sm100, arch::OpClassTensorOp> {
static int const kMin = 100;
static int const kMax = 100;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library

View File

@ -300,6 +300,101 @@ struct GemmDescription : public OperationDescription {
transform_B(transform_B) {}
};
/// Description of all block-scaled GEMM computations
struct BlockScaledGemmDescription : public OperationDescription {

  /// Indicates the kind of GEMM performed
  GemmKind gemm_kind;

  /// Describes the A operand
  TensorDescription A;

  /// Describes the B operand
  TensorDescription B;

  /// Describes the source matrix
  TensorDescription C;

  /// Describes the destination matrix
  TensorDescription D;

  /// Describes the SFA (scale factor of A) operand
  TensorDescription SFA;

  /// Describes the SFB (scale factor of B) operand
  TensorDescription SFB;

  /// Describes the SFD (output scale factor) operand
  TensorDescription SFD;

  /// Describes the data type of the scalars passed to the epilogue
  NumericTypeID element_epilogue;

  /// Describes the structure of parallel reductions
  SplitKMode split_k_mode;

  /// Transformation on A operand
  ComplexTransform transform_A;

  /// Transformation on B operand
  ComplexTransform transform_B;

  /// Describes the input ScaleFactor VectorSize.
  /// Default-initialized to 0: neither constructor sets this member, so without the
  /// initializer it was left indeterminate; callers populate it after construction.
  int SFVecSize = 0;

  /// Describes the Output ScaleFactor VectorSize.
  /// Default-initialized to 0 for the same reason as SFVecSize.
  int EpilogueSFVecSize = 0;

  //
  // Methods
  //

  /// Default-constructible; SFA/SFB/SFD and the vector sizes are expected to be
  /// filled in by the operation that owns this description.
  BlockScaledGemmDescription(
    GemmKind gemm_kind = GemmKind::kGemm,
    TensorDescription const& A = TensorDescription(),
    TensorDescription const& B = TensorDescription(),
    TensorDescription const& C = TensorDescription(),
    TensorDescription const& D = TensorDescription(),
    NumericTypeID element_epilogue = NumericTypeID::kInvalid,
    SplitKMode split_k_mode = SplitKMode::kNone,
    ComplexTransform transform_A = ComplexTransform::kNone,
    ComplexTransform transform_B = ComplexTransform::kNone
  ):
    gemm_kind(gemm_kind),
    A(A),
    B(B),
    C(C),
    D(D),
    element_epilogue(element_epilogue),
    split_k_mode(split_k_mode),
    transform_A(transform_A),
    transform_B(transform_B) {}

  /// Constructor that additionally seeds the base OperationDescription.
  BlockScaledGemmDescription(
    OperationDescription op_desc,
    GemmKind gemm_kind,
    TensorDescription const& A,
    TensorDescription const& B,
    TensorDescription const& C,
    TensorDescription const& D,
    NumericTypeID element_epilogue,
    SplitKMode split_k_mode,
    ComplexTransform transform_A,
    ComplexTransform transform_B
  ):
    OperationDescription(op_desc),
    gemm_kind(gemm_kind),
    A(A),
    B(B),
    C(C),
    D(D),
    element_epilogue(element_epilogue),
    split_k_mode(split_k_mode),
    transform_A(transform_A),
    transform_B(transform_B) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Description for structured sparse GEMMs.

View File

@ -178,6 +178,15 @@ public:
int M, /// GEMM M dimension
int N, /// GEMM N dimension
int K, /// GEMM K dimension
int cluster_m, /// cluster shape M dimension
int cluster_n, /// cluster shape N dimension
int cluster_k, /// cluster shape K dimension
int cluster_m_fallback, /// Fallback cluster shape M dimension
int cluster_n_fallback, /// Fallback cluster shape N dimension
int cluster_k_fallback, /// Fallback cluster shape K dimension
NumericTypeID element_compute, /// Data type of internal accumulation
NumericTypeID element_scalar, /// Data type of alpha/beta scalars

View File

@ -103,6 +103,7 @@ public:
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const = 0;
// Originally designed for metadata, but should be useful for FP8/6/4 too.
virtual Status initialize_with_profiler_workspace(
void const *configuration,
void *host_workspace,
@ -269,6 +270,8 @@ struct GemmUniversalConfiguration {
GemmUniversalMode mode{GemmUniversalMode::kGemm};
gemm::GemmCoord problem_size{};
gemm::GemmCoord cluster_shape{};
gemm::GemmCoord cluster_shape_fallback{};
int batch_count{1};
int64_t lda{0};
@ -282,6 +285,8 @@ struct GemmUniversalConfiguration {
struct GemmUniversalArguments {
// NOTE: these are replicated for 3.0 interfaces
gemm::GemmCoord problem_size{};
gemm::GemmCoord cluster_shape{};
gemm::GemmCoord cluster_shape_fallback{};
int batch_count{1};
void const *A{nullptr};
@ -307,13 +312,68 @@ struct GemmUniversalArguments {
// Needed for some 3.x kernels
int sm_count{0};
library::RasterOrder raster_order{};
library::RuntimeDatatype runtime_input_datatype_a{};
library::RuntimeDatatype runtime_input_datatype_b{};
int swizzle_size{1};
int split_k_slices{1};
int device_index{0};
bool use_pdl{false};
};
/// Arguments for a block-scaled GEMM: a universal GEMM whose A/B operands carry
/// per-block scale factors (SFA/SFB) and whose epilogue may emit an output scale
/// factor tensor (SFD).
//
// OperationKind: kBlockScaledGemm
// GemmKind: Universal
struct BlockScaledGemmArguments {
// NOTE: these are replicated for 3.0 interfaces
gemm::GemmCoord problem_size{};
gemm::GemmCoord cluster_shape{};
// Fallback cluster shape used when the preferred cluster shape cannot be launched
gemm::GemmCoord cluster_shape_fallback{};
int batch_count{1};
// Device pointers to operands (type-erased; element types come from the operation)
void const *A{nullptr};
void const *B{nullptr};
// Scale factors for A and B
void const *SFA{nullptr};
void const *SFB{nullptr};
void const *C{nullptr};
void *D{nullptr};
// Output scale factor tensor; only consumed when the epilogue generates SFD
void *SFD{nullptr};
// Epilogue scalars; interpreted per pointer_mode (host value vs device pointer)
void const *alpha{nullptr};
void const *beta{nullptr};
ScalarPointerMode pointer_mode{};
// NOTE: these are replicated for 3.0 interfaces
int64_t lda{0};
int64_t ldb{0};
int64_t ldc{0};
int64_t ldd{0};
int64_t batch_stride_A{0};
int64_t batch_stride_B{0};
int64_t batch_stride_C{0};
int64_t batch_stride_D{0};
// Needed for ScaleFactor Generation
void const *norm_constant{nullptr};
// Needed for some 3.x kernels
int sm_count{0};
library::RasterOrder raster_order{};
int swizzle_size{1};
int split_k_slices{1};
// Runtime-selected input datatypes (kStatic when the kernel's types are compile-time)
library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic};
library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic};
// Launch with programmatic dependent launch (PDL) when supported
bool use_pdl{false};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Complex valued GEMM in which real and imaginary parts are separated by a stride

View File

@ -243,6 +243,191 @@ using GemmOperationFunctionalMap = std::unordered_map<
GemmFunctionalKeyHasher
>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// Data Structures for BlockScaled Gemm Functional Maps
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tuple uniquely identifying the functional behavior of a block-scaled GEMM:
/// provider, operation/GEMM kind, compute/scalar types, and the element type and
/// layout of every operand including the scale factors (SFA/SFB/SFD).
struct BlockScaledGemmFunctionalKey {

  Provider provider;
  GemmKind gemm_kind;
  OperationKind kind;
  NumericTypeID element_compute;
  NumericTypeID element_scalar;
  NumericTypeID element_A;
  LayoutTypeID layout_A;
  NumericTypeID element_SFA;
  NumericTypeID element_B;
  LayoutTypeID layout_B;
  NumericTypeID element_SFB;
  NumericTypeID element_C;
  LayoutTypeID layout_C;
  NumericTypeID element_D;
  LayoutTypeID layout_D;
  NumericTypeID element_SFD;
  LayoutTypeID layout_SFD;
  int SFVecSize;
  int EpilogueSFVecSize;

  //
  // Methods
  //

  /// Constructs a key; every attribute except the provider has a default.
  inline
  BlockScaledGemmFunctionalKey(
    Provider provider,
    GemmKind gemm_kind = GemmKind::kGemm,
    OperationKind kind = OperationKind::kBlockScaledGemm,
    NumericTypeID element_compute = NumericTypeID::kF32,
    NumericTypeID element_scalar = NumericTypeID::kF32,
    NumericTypeID element_A = NumericTypeID::kF16,
    LayoutTypeID layout_A = LayoutTypeID::kColumnMajor,
    NumericTypeID element_SFA = NumericTypeID::kF16,
    NumericTypeID element_B = NumericTypeID::kF16,
    LayoutTypeID layout_B = LayoutTypeID::kColumnMajor,
    NumericTypeID element_SFB = NumericTypeID::kF16,
    NumericTypeID element_C = NumericTypeID::kF16,
    LayoutTypeID layout_C = LayoutTypeID::kColumnMajor,
    NumericTypeID element_D = NumericTypeID::kF16,
    LayoutTypeID layout_D = LayoutTypeID::kColumnMajor,
    NumericTypeID element_SFD = NumericTypeID::kF16,
    LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor,
    int sf_vec_size = 32,
    int epilogue_sf_vec_size = 32
  ):
    provider(provider),
    gemm_kind(gemm_kind),
    kind(kind),
    element_compute(element_compute),
    element_scalar(element_scalar),
    element_A(element_A),
    layout_A(layout_A),
    element_SFA(element_SFA),
    element_B(element_B),
    layout_B(layout_B),
    element_SFB(element_SFB),
    element_C(element_C),
    layout_C(layout_C),
    element_D(element_D),
    layout_D(layout_D),
    element_SFD(element_SFD),
    layout_SFD(layout_SFD),
    SFVecSize(sf_vec_size),
    EpilogueSFVecSize(epilogue_sf_vec_size)
  { }

  /// True only when every functional attribute matches.
  inline
  bool operator==(BlockScaledGemmFunctionalKey const &rhs) const {
    bool kinds_match =
      (provider == rhs.provider) &&
      (gemm_kind == rhs.gemm_kind) &&
      (kind == rhs.kind) &&
      (element_compute == rhs.element_compute) &&
      (element_scalar == rhs.element_scalar) &&
      (SFVecSize == rhs.SFVecSize) &&
      (EpilogueSFVecSize == rhs.EpilogueSFVecSize);

    bool operands_match =
      (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
      (element_SFA == rhs.element_SFA) &&
      (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
      (element_SFB == rhs.element_SFB) &&
      (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
      (element_D == rhs.element_D) && (layout_D == rhs.layout_D) &&
      (element_SFD == rhs.element_SFD) && (layout_SFD == rhs.layout_SFD);

    return kinds_match && operands_match;
  }

  /// Negation of operator==
  inline
  bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const {
    return !(*this == rhs);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Prints a BlockScaledGemmFunctionalKey in a human-readable, brace-delimited form,
/// one field per line.
inline
std::ostream & operator<<(std::ostream &out, cutlass::library::BlockScaledGemmFunctionalKey const &k) {
  out << "{\n"
    << " provider: " << to_string(k.provider) << "\n"
    << " gemm_kind: " << to_string(k.gemm_kind) << "\n"
    << " kind: " << to_string(k.kind) << "\n"
    << " element_compute: " << to_string(k.element_compute) << "\n"
    << " element_scalar: " << to_string(k.element_scalar) << "\n"
    << " element_A: " << to_string(k.element_A) << "\n"
    << " layout_A: " << to_string(k.layout_A) << "\n"
    << " element_SFA: " << to_string(k.element_SFA) << "\n"
    << " element_B: " << to_string(k.element_B) << "\n"
    << " layout_B: " << to_string(k.layout_B) << "\n"
    << " element_SFB: " << to_string(k.element_SFB) << "\n"
    << " element_C: " << to_string(k.element_C) << "\n"
    << " layout_C: " << to_string(k.layout_C) << "\n"
    << " element_D: " << to_string(k.element_D) << "\n"
    << " layout_D: " << to_string(k.layout_D) << "\n"
    << " element_SFD: " << to_string(k.element_SFD) << "\n"
    << " layout_SFD: " << to_string(k.layout_SFD) << "\n"
    << " SFVecSize: " << k.SFVecSize << "\n"
    // Fix: leading space added so this label aligns with every other field above.
    << " EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n"
    << "}";

  return out;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Hash functor for BlockScaledGemmFunctionalKey (for use with std::unordered_map)
struct BlockScaledGemmFunctionalKeyHasher {
  using IntHash = std::hash<int>;

  /// Rotates `key` left by `shl` bits (callers pass shl in 1..19, so the
  /// complementary right shift is always in range).
  inline
  static size_t rotl(size_t key, int shl) {
    return (key << shl) | (key >> (sizeof(key) * 8u - static_cast<size_t>(shl)));
  }

  /// XOR-combines the hash of every field, each rotated by a distinct amount
  /// (1 through 19, in declaration order) so identical values in different
  /// fields contribute differently.
  inline
  size_t operator()(BlockScaledGemmFunctionalKey const &key) const {
    IntHash hash;

    int const fields[] = {
      int(key.provider),
      int(key.gemm_kind),
      int(key.kind),
      int(key.element_compute),
      int(key.element_scalar),
      int(key.element_A),
      int(key.layout_A),
      int(key.element_SFA),
      int(key.element_B),
      int(key.layout_B),
      int(key.element_SFB),
      int(key.element_C),
      int(key.layout_C),
      int(key.element_D),
      int(key.layout_D),
      int(key.element_SFD),
      int(key.layout_SFD),
      key.SFVecSize,
      key.EpilogueSFVecSize
    };

    size_t result = 0;
    int shift = 0;
    for (int field : fields) {
      result ^= rotl(hash(field), ++shift);
    }
    return result;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Maps a BlockScaledGemmFunctionalKey onto a vector of Operation * objects expected to be of
/// kind kBlockScaledGemm
using BlockScaledGemmOperationFunctionalMap = std::unordered_map<
BlockScaledGemmFunctionalKey,
GemmOperationVectorMap,
BlockScaledGemmFunctionalKeyHasher
>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Data Structures for Conv Functional Maps
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -509,6 +694,9 @@ public:
// provider (kCUTLASS)
GemmOperationFunctionalMap gemm_operations;
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations;
/// Map of all operations of type kConv2d
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
ConvOperationFunctionalMap conv2d_operations;

View File

@ -43,6 +43,7 @@ enum class LayoutTypeID {
kUnknown,
kColumnMajor,
kRowMajor,
kBlockScalingTensor,
kColumnMajorInterleavedK2,
kRowMajorInterleavedK2,
kColumnMajorInterleavedK4,
@ -83,6 +84,16 @@ enum class NumericTypeID {
kS64,
kFE4M3,
kFE5M2,
kFE2M3,
kFE3M2,
kFE2M1,
kFUE8M0,
kFUE4M3,
kF8,
kF6,
kF4,
kF16,
kBF16,
kTF32,
@ -131,6 +142,7 @@ enum class Provider {
/// Enumeration indicating the kind of operation
enum class OperationKind {
kGemm,
kBlockScaledGemm,
kRankK,
kRank2K,
kTrmm,
@ -165,6 +177,7 @@ enum class OpcodeClassID {
kTensorOp,
kWmmaTensorOp,
kSparseTensorOp,
kBlockScaledOp,
kInvalid
};
@ -188,6 +201,7 @@ enum class MathOperationID {
/// Enumeration indicating what kind of GEMM operation to perform
enum class GemmKind {
kGemm,
kBlockScaledGemm,
kSparse,
kUniversal,
kPlanarComplex,
@ -251,6 +265,20 @@ enum class EpilogueKind {
kInvalid
};
/// Narrow-precision datatypes selectable at kernel launch time for kernels whose
/// operand types are chosen at runtime (naming follows the ExMy exponent/mantissa
/// bit-count convention).
enum class RuntimeDatatype {
kStatic, // datatype is fixed at compile time (no runtime selection)
kE4M3, // 4 exponent bits, 3 mantissa bits
kE5M2, // 5 exponent bits, 2 mantissa bits
kE3M2, // 3 exponent bits, 2 mantissa bits
kE2M3, // 2 exponent bits, 3 mantissa bits
kE2M1, // 2 exponent bits, 1 mantissa bit
kInvalid
};
enum class RasterOrder {
kAlongN,
kAlongM,

View File

@ -170,6 +170,15 @@ char const *to_string(ConvKind type, bool pretty = false);
template <>
ConvKind from_string<ConvKind>(std::string const &str);
/// Converts a RuntimeDatatype enumerant to a string
char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false);
/// Converts a RuntimeDatatype enumerant from a string
template<>
cutlass::library::RuntimeDatatype from_string<cutlass::library::RuntimeDatatype>(std::string const &str);
/// Converts a RasterOrder enumerant to a string
char const *to_string(RasterOrder type, bool pretty = false);
@ -202,6 +211,8 @@ bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t
/// Casts from a real value represented as a double to the destination type. Returns true if successful.
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type);
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library

View File

@ -0,0 +1,450 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines operations for all GEMM operation kinds in CUTLASS Library.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/library/library.h"
#include "library_internal.h"
#include "gemm_operation_3x.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::library {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// CUTLASS Library operation wrapping a CUTLASS 3.x block-scaled GEMM kernel.
///
/// Block-scaled GEMMs consume per-block scale factors for the A and B operands
/// (SFA/SFB) and may optionally generate an output scale-factor tensor (SFD) in the
/// epilogue. This adapter translates the type-erased `BlockScaledGemmArguments`
/// into the kernel's native `Operator::Arguments`.
template <typename Operator_>
class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
public:

  using Operator = Operator_;
  using OperatorArguments = typename Operator::Arguments;

  // Operand element types and layouts surfaced from the kernel's collectives
  using ElementA = typename Operator::CollectiveMainloop::ElementA;
  using ElementSFA = typename Operator::CollectiveMainloop::ElementSF;
  using LayoutA = typename Operator::LayoutA;
  using ElementB = typename Operator::CollectiveMainloop::ElementB;
  using ElementSFB = typename Operator::CollectiveMainloop::ElementSF;
  using LayoutB = typename Operator::LayoutB;
  using ElementC = typename Operator::ElementC;
  using LayoutC = typename Operator::LayoutC;
  using ElementD = typename Operator::ElementD;
  using LayoutD = typename Operator::LayoutD;
  using ElementAccumulator = typename Operator::ElementAccumulator;
  using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;

  using TiledMma = typename Operator::CollectiveMainloop::TiledMma;
  constexpr static int SFVecSize = TiledMma::SFVecSize;

  using CollectiveMainloop = typename Operator::CollectiveMainloop;
  using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
  using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
  using Sm100BlkScaledConfig = typename CollectiveMainloop::Sm100BlkScaledConfig;

  // True when the epilogue produces an output scale-factor tensor (SFD)
  static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v<typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;
  static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize;

  using ElementSFD = cute::conditional_t<epilogue_scalefactor_generation, typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;
  using LayoutSFD = cute::conditional_t<epilogue_scalefactor_generation, typename ThreadEpilogueOp::GmemLayoutTagScalefactor, LayoutD>;

  // Runtime datatypes: A and B must agree (both runtime-typed or both static)
  static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
  static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
  static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
                (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
                "ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
  static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;

  using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA;
  using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB;

private:

  /// Static description of this kernel instance, populated once in the constructor
  BlockScaledGemmDescription description_;

public:

  /// Constructor: fills in the functional and tile description of the wrapped kernel
  BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"):
    GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {

    description_.kind = OperationKind::kBlockScaledGemm;

    // Scale-factor operand descriptions; SFA/SFB are reported as row-major
    description_.SFA.element = NumericTypeMap<ElementSFA>::kId;
    description_.SFA.layout = LayoutTypeID::kRowMajor;
    description_.SFA.alignment = 128;
    description_.SFA.log_extent_range = 32;
    description_.SFA.log_stride_range = 32;

    description_.SFB.element = NumericTypeMap<ElementSFB>::kId;
    description_.SFB.layout = LayoutTypeID::kRowMajor;
    description_.SFB.alignment = 128;
    description_.SFB.log_extent_range = 32;
    description_.SFB.log_stride_range = 32;

    description_.SFVecSize = SFVecSize;

    description_.SFD = make_TensorDescription<ElementSFD, LayoutSFD>(128);
    description_.EpilogueSFVecSize = SFD_VectorSize;

    description_.name = name;
    description_.provider = Provider::kCUTLASS;
    description_.gemm_kind = GemmKind::kUniversal;

    description_.tile_description.threadblock_shape = make_Coord(
      Operator::ThreadblockShape::kM,
      Operator::ThreadblockShape::kN,
      Operator::ThreadblockShape::kK);

    // Cluster shapes only exist on SM90+ kernels
    if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
      description_.tile_description.cluster_shape = make_Coord(
        Operator::ClusterShape::kM,
        Operator::ClusterShape::kN,
        Operator::ClusterShape::kK);
    }

    description_.tile_description.threadblock_stages = Operator::kStages;

    description_.tile_description.warp_count = make_Coord(
      Operator::WarpCount::kM,
      Operator::WarpCount::kN,
      Operator::WarpCount::kK);

    description_.tile_description.math_instruction.instruction_shape = make_Coord(
      Operator::InstructionShape::kM,
      Operator::InstructionShape::kN,
      Operator::InstructionShape::kK);

    description_.tile_description.math_instruction.element_accumulator =
      NumericTypeMap<ElementAccumulator>::kId;

    description_.tile_description.math_instruction.opcode_class =
      OpcodeClassMap<typename Operator::OperatorClass>::kId;

    description_.tile_description.math_instruction.math_operation =
      MathOperationMap<typename Operator::MathOperator>::kId;

    description_.tile_description.minimum_compute_capability =
      ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMin;

    description_.tile_description.maximum_compute_capability =
      ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMax;

    description_.A = make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
    description_.B = make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
    description_.C = make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
    description_.D = make_TensorDescription<ElementD, LayoutD>(Operator::kAlignmentD);

    description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
    description_.split_k_mode = SplitKMode::kNone;
  }

  /// Returns the description of the GEMM operation (type-erased base view)
  virtual OperationDescription const & description() const {
    return description_;
  }

  /// Returns the full block-scaled description of the GEMM operation
  BlockScaledGemmDescription const& get_gemm_description() const {
    return description_;
  }

protected:

  /// Constructs the arguments structure given the configuration and arguments
  static Status construct_arguments_(
      OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
    // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides.
    // Do nothing here and construct kernel arguments in update_arguments_ instead.
    // We also cannot construct TMA descriptors without all the arguments available.
    operator_args.mode = configuration->mode;
    return Status::kSuccess;
  }

  /// Primary template: selected when the fusion arguments expose no `alpha` member
  /// (i.e. a custom EVT is instantiated); the user owns alpha/beta updates.
  template<class FusionArgs, class = void>
  struct UpdateFusionArgs {
    static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
      // If a custom EVT is instantiated then it is the user's responsibility
      // to ensure alpha and beta are updated appropriately.
      return Status::kSuccess;
    }
  };

  /// Specialization selected when the fusion arguments have an `alpha` member:
  /// sets alpha/beta by host value or device pointer, and wires the SFD
  /// generation inputs when the epilogue produces scale factors.
  template<class FusionArgs>
  struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
    static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
      if constexpr (epilogue_scalefactor_generation) {
        fusion_args.block_scale_factor_ptr = static_cast<ElementSFD*>(arguments.SFD);
        fusion_args.norm_constant_ptr = static_cast<ElementCompute const *>(arguments.norm_constant);
      }

      if (arguments.pointer_mode == ScalarPointerMode::kHost) {
        fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
        fusion_args.beta = *static_cast<ElementCompute const *>(arguments.beta);
        fusion_args.alpha_ptr = nullptr;
        fusion_args.beta_ptr = nullptr;
        return Status::kSuccess;
      }
      else if (arguments.pointer_mode == ScalarPointerMode::kDevice) {
        fusion_args.alpha = 0;
        fusion_args.beta = 0;
        fusion_args.alpha_ptr = static_cast<ElementCompute const *>(arguments.alpha);
        fusion_args.beta_ptr = static_cast<ElementCompute const *>(arguments.beta);
        return Status::kSuccess;
      }
      else {
        return Status::kErrorInvalidProblem;
      }
    }
  };

  /// Translates BlockScaledGemmArguments into the kernel's native Arguments
  static Status update_arguments_(
      OperatorArguments &operator_args,
      BlockScaledGemmArguments const *arguments) {
    Status status = Status::kSuccess;

    status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
      operator_args.epilogue.thread, *arguments);
    if (status != Status::kSuccess) {
      return status;
    }

    operator_args.problem_shape = cute::make_shape(
      arguments->problem_size.m(),
      arguments->problem_size.n(),
      arguments->problem_size.k(),
      arguments->batch_count);

    // update arguments
    if constexpr (IsRuntimeDataType) {
      using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
      using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
      operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
      operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);

      using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA;
      using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB;
      static_assert(cute::is_same_v<RuntimeDataTypeA, RuntimeDataTypeB>,
        "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format");
      using RuntimeDatatypeArg = RuntimeDataTypeA;

      // Maps the library's runtime datatype enum onto the kernel's UMMA format enum.
      auto mapping = [](RuntimeDatatype type) {
        if constexpr (cute::is_same_v<RuntimeDatatypeArg, cute::UMMA::MXF8F6F4Format>) {
          if (type == RuntimeDatatype::kE3M2) {
            return cute::UMMA::MXF8F6F4Format::E3M2;
          } else if (type == RuntimeDatatype::kE2M3) {
            return cute::UMMA::MXF8F6F4Format::E2M3;
          } else if (type == RuntimeDatatype::kE2M1) {
            return cute::UMMA::MXF8F6F4Format::E2M1;
          } else {
            // Fix: assert(false && ...) — the original assert("literal") was a
            // non-null pointer and therefore always true, so it could never fire.
            assert(false && "Invalid input datatype.");
          }
        }
        else if constexpr (cute::is_same_v<RuntimeDatatypeArg, cute::UMMA::MXF4Format>) {
          if (type == RuntimeDatatype::kE2M1) {
            return cute::UMMA::MXF4Format::E2M1;
          } else {
            assert(false && "Invalid input datatype.");
          }
        }
        // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype
        CUTE_GCC_UNREACHABLE;
      };

      operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a);
      operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b);
    }
    else {
      operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
      operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
    }

    operator_args.mainloop.ptr_SFA = static_cast<ElementSFA const *>(arguments->SFA);
    operator_args.mainloop.ptr_SFB = static_cast<ElementSFB const *>(arguments->SFB);
    operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
    operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);

    operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
      arguments->lda, arguments->batch_stride_A);
    operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
      arguments->ldb, arguments->batch_stride_B);
    operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
      arguments->ldc, arguments->batch_stride_C);
    // NOTE(review): D reuses C's stride; arguments->ldd and batch_stride_D are
    // ignored here — confirm D is always expected to share C's leading dimension.
    operator_args.epilogue.dD = operator_args.epilogue.dC;

    // Scale-factor layouts are derived from the problem shape, not user strides
    operator_args.mainloop.layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape);
    operator_args.mainloop.layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape);

    /* Query device SM count to pass onto the kernel as an argument, where needed */
    operator_args.hw_info.sm_count = arguments->sm_count;

    // Scheduler knobs are optional per-kernel; only set the ones that are mutable
    if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
      operator_args.scheduler.max_swizzle_size = arguments->swizzle_size;
    }

    if constexpr (!std::is_const_v<decltype(operator_args.scheduler.raster_order)>) {
      using Enum_t = decltype(operator_args.scheduler.raster_order);
      switch (arguments->raster_order) {
        case RasterOrder::kAlongN:
          operator_args.scheduler.raster_order = Enum_t::AlongN;
          break;
        case RasterOrder::kAlongM:
          operator_args.scheduler.raster_order = Enum_t::AlongM;
          break;
        default:
          operator_args.scheduler.raster_order = Enum_t::Heuristic;
      }
    }

    if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
      operator_args.scheduler.splits = arguments->split_k_slices;
    }

    // Preferred and fallback cluster shapes are runtime launch parameters on SM100+
    if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
      operator_args.hw_info.cluster_shape = dim3(
        arguments->cluster_shape.m(),
        arguments->cluster_shape.n(),
        arguments->cluster_shape.k());
      operator_args.hw_info.cluster_shape_fallback = dim3(
        arguments->cluster_shape_fallback.m(),
        arguments->cluster_shape_fallback.n(),
        arguments->cluster_shape_fallback.k());
    }

    return status;
  }

public:

  /// Returns success if the operation can proceed
  Status can_implement(
      void const *configuration_ptr, void const *arguments_ptr) const override {

    GemmUniversalConfiguration const *configuration =
      static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
    BlockScaledGemmArguments const *arguments =
      static_cast<BlockScaledGemmArguments const *>(arguments_ptr);

    OperatorArguments args;
    auto status = update_arguments_(args, arguments);
    if (status != Status::kSuccess) {
      return status;
    }

    // can_implement rules may need access to problem shape
    args.problem_shape = cute::make_shape(
      configuration->problem_size.m(),
      configuration->problem_size.n(),
      configuration->problem_size.k(),
      configuration->batch_count);

    return Operator::can_implement(args);
  }

  /// Gets the host-side workspace (holds the Operator object itself)
  uint64_t get_host_workspace_size(void const *configuration) const override {
    return sizeof(Operator);
  }

  /// Gets the device-side workspace
  uint64_t get_device_workspace_size(
      void const *configuration_ptr, void const *arguments_ptr) const override {

    OperatorArguments args;
    auto status = update_arguments_(
      args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr));
    if (status != Status::kSuccess) {
      return 0;
    }

    uint64_t size = Operator::get_workspace_size(args);
    return size;
  }

  /// Initializes the workspace (placement-constructs the Operator in host workspace)
  Status initialize(
      void const *configuration_ptr,
      void *host_workspace,
      void *device_workspace,
      cudaStream_t stream = nullptr) const override {
    Operator *op = new (host_workspace) Operator;
    return Status::kSuccess;
  }

  /// No-op for this operation; profiler workspaces are not pre-initialized here
  Status initialize_with_profiler_workspace(
      void const *configuration,
      void *host_workspace,
      void *device_workspace,
      uint8_t **profiler_workspaces,
      int problem_count_from_profiler,
      cudaStream_t stream = nullptr) {
    return Status::kSuccess;
  }

  /// Runs the kernel
  Status run(
      void const *arguments_ptr,
      void *host_workspace,
      void *device_workspace = nullptr,
      cudaStream_t stream = nullptr) const override {

    OperatorArguments args;
    Status status = update_arguments_(args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr));
    if (status != Status::kSuccess) {
      return status;
    }

    Operator *op = static_cast<Operator *>(host_workspace);
    // We need to call initialize() since we have to rebuild TMA desc for every new set of args
    status = op->run(args, device_workspace, stream, nullptr, static_cast<BlockScaledGemmArguments const *>(arguments_ptr)->use_pdl);
    return status;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::library
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -162,6 +162,18 @@ public:
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
"ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
public:
/// Constructor
@ -235,8 +247,42 @@ protected:
arguments->batch_count);
// update arguments
if constexpr (IsRuntimeDataType) {
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);
std::unordered_map<RuntimeDatatype, cute::UMMA::MXF8F6F4Format> mapping = {
{RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
{RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
{RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
{RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
};
auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);
if (iter_runtime_a != mapping.end()) {
operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
} else {
assert("invalid runtime argument for datatype A!");
}
if (iter_runtime_b != mapping.end()) {
operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
} else {
assert("invalid runtime argument for datatype B!");
}
}
else {
operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
}
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
@ -277,6 +323,22 @@ protected:
}
}
if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
operator_args.scheduler.splits = arguments->split_k_slices;
}
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
operator_args.hw_info.cluster_shape = dim3(
arguments->cluster_shape.m(),
arguments->cluster_shape.n(),
arguments->cluster_shape.k());
operator_args.hw_info.cluster_shape_fallback = dim3(
arguments->cluster_shape_fallback.m(),
arguments->cluster_shape_fallback.n(),
arguments->cluster_shape_fallback.k());
}
return status;
}

View File

@ -510,6 +510,15 @@ Status Handle::gemm_universal(
int M, /// GEMM M dimension
int N, /// GEMM N dimension
int K, /// GEMM K dimension
int cluster_m, /// cluster shape M dimension
int cluster_n, /// cluster shape N dimension
int cluster_k, /// cluster shape K dimension
int cluster_m_fallback, /// Fallback cluster shape M dimension
int cluster_n_fallback, /// Fallback cluster shape N dimension
int cluster_k_fallback, /// Fallback cluster shape K dimension
NumericTypeID element_compute, /// Data type of internal accumulation
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
@ -629,6 +638,8 @@ Status Handle::gemm_universal(
GemmUniversalConfiguration configuration{
mode,
{M, N, K},
{cluster_m, cluster_n, cluster_k},
{cluster_m_fallback, cluster_n_fallback, cluster_k_fallback},
batch_count,
lda,
ldb,
@ -647,6 +658,8 @@ Status Handle::gemm_universal(
GemmUniversalArguments arguments{
{M, N, K},
{cluster_m, cluster_n, cluster_k},
{cluster_m_fallback, cluster_n_fallback, cluster_k_fallback},
batch_count,
ptr_A,
ptr_B,

View File

@ -116,6 +116,27 @@ template <> struct NumericTypeMap<cutlass::float_e5m2_t> {
static NumericTypeID const kId = NumericTypeID::kFE5M2;
};
/// Maps cutlass::float_e2m3_t to NumericTypeID::kFE2M3.
template <> struct NumericTypeMap<cutlass::float_e2m3_t> {
  static NumericTypeID const kId = NumericTypeID::kFE2M3;
};

/// Maps cutlass::float_e3m2_t to NumericTypeID::kFE3M2.
template <> struct NumericTypeMap<cutlass::float_e3m2_t> {
  static NumericTypeID const kId = NumericTypeID::kFE3M2;
};

/// Maps cutlass::float_e2m1_t to NumericTypeID::kFE2M1.
template <> struct NumericTypeMap<cutlass::float_e2m1_t> {
  static NumericTypeID const kId = NumericTypeID::kFE2M1;
};

/// Maps cutlass::float_ue8m0_t (unsigned scale-factor type) to NumericTypeID::kFUE8M0.
template <> struct NumericTypeMap<cutlass::float_ue8m0_t> {
  static NumericTypeID const kId = NumericTypeID::kFUE8M0;
};

/// Maps cutlass::float_ue4m3_t (unsigned scale-factor type) to NumericTypeID::kFUE4M3.
template <> struct NumericTypeMap<cutlass::float_ue4m3_t> {
  static NumericTypeID const kId = NumericTypeID::kFUE4M3;
};

/// Maps uint16_t to NumericTypeID::kU16.
template <> struct NumericTypeMap<uint16_t> {
  static NumericTypeID const kId = NumericTypeID::kU16;
};
@ -161,6 +182,21 @@ template <> struct NumericTypeMap<cutlass::tfloat32_t> {
};
/// Maps the type-erased (runtime-selected) 8-bit float to NumericTypeID::kF8.
template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float8_t> {
  static NumericTypeID const kId = NumericTypeID::kF8;
};

/// Maps the type-erased (runtime-selected) 6-bit float to NumericTypeID::kF6.
template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float6_t> {
  static NumericTypeID const kId = NumericTypeID::kF6;
};

/// Maps the type-erased (runtime-selected) 4-bit float to NumericTypeID::kF4.
template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float4_t> {
  static NumericTypeID const kId = NumericTypeID::kF4;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T> struct MathOperationMap {
@ -300,6 +336,12 @@ template <> struct OpcodeClassMap<arch::OpClassSparseTensorOp> {
static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp;
};
/// Maps arch::OpClassBlockScaledTensorOp to OpcodeClassID::kBlockScaledOp.
template <> struct OpcodeClassMap<arch::OpClassBlockScaledTensorOp> {
  static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp;
};

/// Maps arch::OpClassWmmaTensorOp to OpcodeClassID::kWmmaTensorOp.
template <> struct OpcodeClassMap<arch::OpClassWmmaTensorOp> {
  static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp;
};

View File

@ -48,6 +48,45 @@ void OperationTable::append(Manifest const &manifest) {
// Insert operations into appropriate data structure
for (auto const & operation : manifest) {
OperationDescription const &desc = operation->description();
if (desc.kind == OperationKind::kBlockScaledGemm) {
BlockScaledGemmDescription const &gemm_desc = static_cast<BlockScaledGemmDescription const &>(desc);
BlockScaledGemmFunctionalKey functional_key(
gemm_desc.provider,
gemm_desc.gemm_kind,
gemm_desc.kind,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
gemm_desc.A.element,
gemm_desc.A.layout,
gemm_desc.SFA.element,
gemm_desc.B.element,
gemm_desc.B.layout,
gemm_desc.SFB.element,
gemm_desc.C.element,
gemm_desc.C.layout,
gemm_desc.D.element,
gemm_desc.D.layout,
gemm_desc.SFD.element,
gemm_desc.SFD.layout,
gemm_desc.SFVecSize
, gemm_desc.EpilogueSFVecSize
);
Operation const *op = operation.get();
int cc = gemm_desc.tile_description.minimum_compute_capability;
int alignment = std::max(std::max(
gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment);
GemmPreferenceKey preference_key(cc, alignment);
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
}
// insert all gemm operation into operation table
if (desc.kind == OperationKind::kGemm) {
GemmDescription const &gemm_desc = static_cast<GemmDescription const &>(desc);

View File

@ -0,0 +1,128 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "block_scaled_gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Registers block-scaled GEMM reference operations with scale-factor vector
/// size 16 into the manifest. All entries use float_e2m1_t (fp4) A/B operands
/// in TN layout; they differ in the scale-factor type (ue4m3 vs ue8m0), the
/// C/D element types, and whether an output scale factor (SFD) is generated.
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest) {
  //////////////////////////////////////////////////////////////////////////////////////////////////////////
  // SFVectorSize = 16 with MxF4NvF4 instructions
  //////////////////////////////////////////////////////////////////////////////////////////////////////////
  //  (float_e2m1_t * float_ue4m3_t) * (float_e2m1_t * float_ue4m3_t)
  // No C operand, fp32 D, no SFD generation.
  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
    void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 16 /*SFVecSize*/
  >(manifest);

  // No C operand, e5m2 D.
  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
    void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
  >(manifest);

  // half_t C operand, e5m2 D.
  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
  >(manifest);

  // Variants that also generate a ue8m0 output scale factor (SFD).
  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
    void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
    16 /*EpilogueSFVecSize*/
  >(manifest);

  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
    32 /*EpilogueSFVecSize*/
  >(manifest);

  //  (float_e2m1_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 16 /*SFVecSize*/
  >(manifest);

  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
  >(manifest);

  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
  >(manifest);

  // Variants with ue8m0 SFD generation, covering both epilogue SF vector sizes
  // (16 and 32) and both void/half_t C operands.
  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
    16 /*EpilogueSFVecSize*/
  >(manifest);

  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
    16 /*EpilogueSFVecSize*/
  >(manifest);

  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
    32 /*EpilogueSFVecSize*/
  >(manifest);

  make_block_scaled_gemm_tn<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
    32 /*EpilogueSFVecSize*/
  >(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,130 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "block_scaled_gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Registers block-scaled GEMM reference operations with scale-factor vector
/// size 32 into the manifest. All entries use float_e2m1_t (fp4) A/B operands
/// with ue8m0 scale factors; they differ in C/D element types and in whether
/// (and at what epilogue vector size) an output scale factor (SFD) is produced.
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest) {
  //////////////////////////////////////////////////////////////////////////////////////////////////////////
  // SFVectorSize = 32 with MxF4 instructions
  //////////////////////////////////////////////////////////////////////////////////////////////////////////
  //  (float_e2m1_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
  // No C operand, fp32 D, no SFD generation.
  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
  >(manifest);

  // With SF generation reference: these variants also produce a ue8m0 output
  // scale factor (SFD), at epilogue SF vector size 16 or 32.
  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
    16 /*EpiSFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
    16 /*EpiSFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
    32 /*EpiSFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/,
    32 /*EpiSFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
    32 /*EpiSFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
    32 /*EpiSFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/,
    32 /*EpiSFVecSize*/
  >(manifest);

  make_block_scaled_gemm<
    float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
    half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
    32 /*EpiSFVecSize*/
  >(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,354 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "block_scaled_gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest) {
//////////////////////////////////////////////////////////////////////////////////////////////////////////
// SFVectorSize = 32 with MxF8F6F4 instructions
//////////////////////////////////////////////////////////////////////////////////////////////////////////
// (float_e2m3_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
// (float_e4m3_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
// (float_e2m3_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
// (float_e2m1_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
// (float_e4m3_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
// (float_e4m3_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
// (float_e3m2_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
// (float_e2m1_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
// (float_e2m3_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
make_block_scaled_gemm<
float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/,
32 /*EpilogueSFVecSize*/
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,459 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library
*/
#pragma once
#include <iostream>
#include <sstream>
#include <cstring>
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/util.h"
#include "cutlass/util/packed_stride.hpp"
#include "library_internal.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
namespace detail {

/// Returns an iterator appropriate for element type T: byte-addressable types
/// can be accessed through their raw pointer, while sub-byte types (e.g. the
/// 4-bit and 6-bit float formats) must go through cute::subbyte_iterator so
/// individual elements can be addressed.
template <typename T>
auto make_iterator(T* ptr) {
  if constexpr (!cute::is_subbyte_v<T>) {
    return ptr;
  }
  else {
    return cute::subbyte_iterator<T>(ptr);
  }
}

}
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
  Provider Provider_,
  typename ElementA_,
  typename LayoutA_,
  typename ElementSFA_,
  typename ElementB_,
  typename LayoutB_,
  typename ElementSFB_,
  typename ElementC_,
  typename LayoutC_,
  typename ElementCompute_,
  typename ElementAccumulator_ = ElementCompute_,
  typename ElementD_ = ElementC_,
  typename ElementSFD_ = void,
  typename LayoutSFD_ = LayoutC_,
  int SFVecSize_ = 32,
  int EpilogueSFVecSize_ = 0,
  typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
/// Host reference operation for block-scaled GEMM (D = alpha * (A*SFA)(B*SFB) + beta * C,
/// optionally producing scale factors SFD alongside D when ElementSFD is non-void).
/// Registered in the manifest so the profiler can verify device kernels against it.
class BlockScaledGemmReferenceOperation : public Operation {
public:
  static Provider const kProvider = Provider_;
  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementSFA = ElementSFA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementSFB = ElementSFB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using ElementD = ElementD_;
  using ElementSFD = ElementSFD_;
  using LayoutSFD = LayoutSFD_;
  using ElementCompute = ElementCompute_;
  using ElementAccumulator = ElementAccumulator_;
  using ConvertOp = ConvertOp_;
  using InnerProductOp = InnerProductOp_;
  constexpr static int SFVecSize = SFVecSize_;
  constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_;
protected:
  /// Storage for the name string; description_.name points into this buffer.
  std::string name_;
  /// Description reported through description()
  BlockScaledGemmDescription description_;
public:
  /// Constructor: fills in the operation description and builds the procedural name.
  BlockScaledGemmReferenceOperation() {
    // Basic information
    description_.provider = kProvider;
    description_.kind = OperationKind::kBlockScaledGemm;
    description_.gemm_kind = GemmKind::kUniversal;
    // Tensor description
    description_.A = make_TensorDescription<ElementA, LayoutA>();
    description_.SFA = make_TensorDescription<ElementSFA, LayoutA>();
    description_.B = make_TensorDescription<ElementB, LayoutB>();
    description_.SFB = make_TensorDescription<ElementSFB, LayoutB>();
    description_.C = make_TensorDescription<ElementC, LayoutC>();
    description_.D = make_TensorDescription<ElementD, LayoutC>();
    description_.SFD = make_TensorDescription<ElementSFD, LayoutSFD>();
    // Epilogue compute and accumulator type description
    description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
    description_.tile_description.math_instruction.element_accumulator =
      NumericTypeMap<ElementAccumulator>::kId;
    // Compute capability for gemm reference
    description_.tile_description.minimum_compute_capability =
      (kProvider == Provider::kReferenceDevice ? 50 : 0);
    description_.tile_description.maximum_compute_capability = 1024;
    description_.SFVecSize = SFVecSize;
    description_.EpilogueSFVecSize = EpilogueSFVecSize;
    // Procedural name: encodes provider, element types, and layouts of all operands.
    std::stringstream ss;
    ss << "gemm"
      << "_reference_" << to_string(description_.provider)
      << "_" << to_string(description_.A.element) << to_string(description_.A.layout)
      << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout)
      << "_" << to_string(description_.B.element) << to_string(description_.B.layout)
      << "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout)
      << "_" << to_string(description_.C.element) << to_string(description_.C.layout)
      << "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout)
      << "_" << to_string(description_.tile_description.math_instruction.element_accumulator);
    name_ = ss.str();
    description_.name = name_.c_str();
    // NOTE: a second, redundant assignment of element_epilogue / element_accumulator
    // that appeared here was removed; the values are already set above.
  }
  /// Returns the description of the GEMM operation
  virtual OperationDescription const & description() const {
    return description_;
  }
  /// Host references accept any configuration; feasibility is not checked here.
  virtual Status can_implement(
    void const *configuration,
    void const *arguments) const {
    return Status::kSuccess;
  }
  virtual uint64_t get_host_workspace_size(
    void const *configuration) const {
    return sizeof(GemmUniversalConfiguration);
  }
  /// The host reference needs no device workspace.
  virtual uint64_t get_device_workspace_size(
    void const *configuration,
    void const *arguments = nullptr) const {
    return 0;
  }
  virtual Status initialize(
    void const *configuration,
    void *host_workspace,
    void *device_workspace = nullptr,
    cudaStream_t stream = nullptr) const {
    return Status::kSuccess;
  }
  /// Runs the host reference GEMM. `arguments` must point to BlockScaledGemmArguments
  /// whose pointers reference host-accessible memory; `stream` is ignored.
  virtual Status run(
    void const *arguments,
    void *host_workspace,
    void *device_workspace = nullptr,
    cudaStream_t stream = nullptr) const {
    using namespace cute;
    BlockScaledGemmArguments const &args = *static_cast<BlockScaledGemmArguments const *>(arguments);
    // Construct cute::Tensor A/B/C
    int M = args.problem_size.m();
    int N = args.problem_size.n();
    int K = args.problem_size.k();
    int L = args.batch_count;
    auto problem_shape_MNKL = cute::make_shape(M, N, K, L);
    auto alpha = *(static_cast<ElementCompute const*>(args.alpha));
    auto beta = *(static_cast<ElementCompute const*>(args.beta));
    using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
    using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
    using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
    using StrideD = cutlass::gemm::TagToStrideC_t<LayoutC>;
    auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
    auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
    auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
    auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
    // Scale-factor tensors follow the SM100 block-scaled layout for the given vector size.
    using Sm100BlockScaledConfig = cutlass::detail::Sm100BlockScaledConfig<SFVecSize>;
    auto A = cute::make_tensor(detail::make_iterator(static_cast<ElementA const*>(args.A)),
      cute::make_layout(cute::make_shape(M, K, L), stride_a));
    auto SfA = make_tensor(static_cast<ElementSFA const*>(args.SFA), Sm100BlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL));
    auto B = cute::make_tensor(detail::make_iterator(static_cast<ElementB const*>(args.B)),
      cute::make_layout(cute::make_shape(N, K, L), stride_b));
    auto SfB = make_tensor(static_cast<ElementSFB const*>(args.SFB), Sm100BlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL));
    // When C is void (no source), substitute a null tensor of ElementD so the
    // epilogue params still type-check; beta handling makes it unused.
    auto C = [&]() {
      if constexpr (not is_same_v<ElementC, void>) {
        return cute::make_tensor(detail::make_iterator(static_cast<ElementC const*>(args.C)),
          cute::make_layout(cute::make_shape(M, N, L), stride_c));
      }
      else {
        return cute::make_tensor(detail::make_iterator(static_cast<ElementD const*>(nullptr)),
          cute::make_layout(cute::make_shape(M, N, L), stride_c));
      }
    }();
    auto D = cute::make_tensor(detail::make_iterator(static_cast<ElementD *>(args.D)),
      cute::make_layout(cute::make_shape(M, N, L), stride_d));
    cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
        decltype(A), decltype(SfA),
        decltype(B), decltype(SfB)>
      mainloop_params{A, SfA, B, SfB};
    if constexpr (not is_same_v<ElementSFD, void>) {
      // Epilogue additionally generates output scale factors SFD.
      using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig<
        EpilogueSFVecSize
      >;
      auto SfD = cute::make_tensor(detail::make_iterator(static_cast<ElementSFD*>(args.SFD)), Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL));
      cutlass::reference::host::GettBlockScalingEpilogueParams<
          ElementCompute, ElementAccumulator, ElementCompute,
          decltype(C), decltype(D), decltype(SfD), Int<EpilogueSFVecSize>, cutlass::reference::host::SfStrategy::SfDGen>
        epilogue_params{alpha, beta, C, D, SfD, *(static_cast<ElementCompute const*>(args.norm_constant))};
      cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
    }
    else {
      // W/O SF generation: placeholder SfD tensor is never read (element type irrelevant).
      auto SfD = cute::make_tensor(static_cast<ElementSFA *>(nullptr),
        cute::make_layout(cute::make_shape(M, N, L))); // not used.
      cutlass::reference::host::GettBlockScalingEpilogueParams<
          ElementCompute, ElementAccumulator, ElementCompute,
          decltype(C), decltype(D), decltype(SfD)>
        epilogue_params{alpha, beta, C, D, SfD};
      cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
    }
    return Status::kSuccess;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename ElementA_,
  typename ElementSFA_,
  typename ElementB_,
  typename ElementSFB_,
  typename ElementC_,
  typename ElementCompute_,
  typename ElementSFD_ = void,
  typename ElementAccumulator_ = ElementCompute_,
  typename ElementD_ = ElementC_,
  int SFVecSize = 32,
  int EpilogueSFVecSize = SFVecSize,
  typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
/// Registers the host reference for the canonical TN case only:
/// A row-major, B column-major, C/D row-major, SFD row-major.
void make_block_scaled_gemm_tn(Manifest &manifest) {
#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE)
  using Operation_ = BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_, cutlass::layout::RowMajor,    ElementSFA_,
    ElementB_, cutlass::layout::ColumnMajor, ElementSFB_,
    ElementC_, cutlass::layout::RowMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_, cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >;
  manifest.append(new Operation_);
#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE)
}
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename ElementA_,
  typename ElementSFA_,
  typename ElementB_,
  typename ElementSFB_,
  typename ElementC_,
  typename ElementCompute_,
  typename ElementSFD_ = void,
  typename ElementAccumulator_ = ElementCompute_,
  typename ElementD_ = ElementC_,
  int SFVecSize = 32,
  int EpilogueSFVecSize = SFVecSize,
  typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
/// Registers host references for the four supported layout combinations
/// (A/B in TN and NT, each with C/D row- and column-major). SFD is always
/// row-major. Guarded like make_block_scaled_gemm_tn so the references can
/// be compiled out via CUTLASS_PROFILER_DISABLE_REFERENCE (previously this
/// function lacked the guard, inconsistently with its sibling).
void make_block_scaled_gemm(Manifest &manifest) {
#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE)
  ///
  /// A is Row , B is Col
  ///
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::RowMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::ColumnMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::RowMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::RowMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::ColumnMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::ColumnMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
  ///
  /// A is Col , B is Row
  ///
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::ColumnMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::RowMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::RowMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::ColumnMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::RowMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::ColumnMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE)
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,109 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
  \brief Instantiates GEMM reference implementations for FP4 (e2m1) A/B operands with f32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A/B : float_e2m1_t (not support float_e0m2_t to reduce ref kernel compile time)
// Acc: f32
// C/D : some variance
// 1. e2m1_e2m1_f32_f16_e4m3
// 2. e2m1_e2m1_f32_f16_e5m2
// 3. e2m1_e2m1_f32_f16_f16
// 4. e2m1_e2m1_f32_f32_f32
/// Registers host reference GEMMs with e2m1 (FP4) A and B operands and f32
/// accumulation, for all canonical layouts.
/// Template argument order: ElementA, ElementB, ElementC, ElementScalar,
/// ElementAccumulator, ElementD.
void initialize_gemm_reference_operations_f4_f4_f32(Manifest &manifest) {
  // e2m1_e2m1_f32_f16_e4m3
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e2m1_t, half_t, float, float, float_e4m3_t>(manifest);
  // e2m1_e2m1_f32_f16_e5m2
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e2m1_t, half_t, float, float, float_e5m2_t>(manifest);
  // e2m1_e2m1_f32_f16_f16
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e2m1_t, half_t, float, float, half_t>(manifest);
  // e2m1_e2m1_f32_f32_f32
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e2m1_t, float, float, float, float>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
  \brief Instantiates GEMM reference implementations for FP4 x FP6 (e2m1 x e3m2) operands with f32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A: float_e2m1_t
// B: float_e3m2_t
// Acc: f32
// C/D : some variance
// 1. e2m1_e3m2_f32_f16_e4m3
// 2. e2m1_e3m2_f32_f16_e5m2
// 3. e2m1_e3m2_f32_f16_f16
// 4. e2m1_e3m2_f32_f32_f32
/// Registers host reference GEMMs with e2m1 (FP4) A and e3m2 (FP6) B operands
/// and f32 accumulation, for all canonical layouts.
/// Template argument order: ElementA, ElementB, ElementC, ElementScalar,
/// ElementAccumulator, ElementD.
void initialize_gemm_reference_operations_f4_f6_f32(Manifest &manifest) {
  // e2m1_e3m2_f32_f16_e4m3
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e3m2_t, half_t, float, float, float_e4m3_t>(manifest);
  // e2m1_e3m2_f32_f16_e5m2
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e3m2_t, half_t, float, float, float_e5m2_t>(manifest);
  // e2m1_e3m2_f32_f16_f16
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e3m2_t, half_t, float, float, half_t>(manifest);
  // e2m1_e3m2_f32_f32_f32
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e3m2_t, float, float, float, float>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
  \brief Instantiates GEMM reference implementations for FP4 x FP8 (e2m1 x e4m3) operands with f32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A: float_e2m1_t
// B: float_e4m3_t
// Acc: f32
// C/D : some variance
// 1. e2m1_e4m3_f32_f16_e4m3
// 2. e2m1_e4m3_f32_f16_e5m2
// 3. e2m1_e4m3_f32_f16_f16
// 4. e2m1_e4m3_f32_f32_f32
/// Registers host reference GEMMs with e2m1 (FP4) A and e4m3 (FP8) B operands
/// and f32 accumulation, for all canonical layouts.
/// Template argument order: ElementA, ElementB, ElementC, ElementScalar,
/// ElementAccumulator, ElementD.
void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest) {
  // e2m1_e4m3_f32_f16_e4m3
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e4m3_t, half_t, float, float, float_e4m3_t>(manifest);
  // e2m1_e4m3_f32_f16_e5m2
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e4m3_t, half_t, float, float, float_e5m2_t>(manifest);
  // e2m1_e4m3_f32_f16_f16
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e4m3_t, half_t, float, float, half_t>(manifest);
  // e2m1_e4m3_f32_f32_f32
  make_gemm_real_canonical_layouts<float_e2m1_t, float_e4m3_t, float, float, float, float>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
  \brief Instantiates GEMM reference implementations for FP6 x FP4 (e3m2 x e2m1) operands with f32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A: float_e3m2_t
// B: float_e2m1_t
// Acc: f32
// C/D : some variance
// 1. e3m2_e2m1_f32_f16_e4m3
// 2. e3m2_e2m1_f32_f16_e5m2
// 3. e3m2_e2m1_f32_f16_f16
// 4. e3m2_e2m1_f32_f32_f32
/// Registers host reference GEMMs with e3m2 (FP6) A and e2m1 (FP4) B operands
/// and f32 accumulation, for all canonical layouts.
/// Template argument order: ElementA, ElementB, ElementC, ElementScalar,
/// ElementAccumulator, ElementD.
void initialize_gemm_reference_operations_f6_f4_f32(Manifest &manifest) {
  // e3m2_e2m1_f32_f16_e4m3
  make_gemm_real_canonical_layouts<float_e3m2_t, float_e2m1_t, half_t, float, float, float_e4m3_t>(manifest);
  // e3m2_e2m1_f32_f16_e5m2
  make_gemm_real_canonical_layouts<float_e3m2_t, float_e2m1_t, half_t, float, float, float_e5m2_t>(manifest);
  // e3m2_e2m1_f32_f16_f16
  make_gemm_real_canonical_layouts<float_e3m2_t, float_e2m1_t, half_t, float, float, half_t>(manifest);
  // e3m2_e2m1_f32_f32_f32
  make_gemm_real_canonical_layouts<float_e3m2_t, float_e2m1_t, float, float, float, float>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,109 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations for FP6 (e3m2) x FP6 (e3m2) inputs with FP32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A/B : float_e3m2_t (not support float_e2m3_t to reduce ref kernel compile time)
// Acc: f32
// C/D : several variants
// 1. e3m2_e3m2_f32_f16_e4m3
// 2. e3m2_e3m2_f32_f16_e5m2
// 3. e3m2_e3m2_f32_f16_f16
// 4. e3m2_e3m2_f32_f32_f32
// Registers reference GEMM operations with both A and B = float_e3m2_t (FP6),
// FP32 accumulation, in canonical layouts, covering the four C/D type
// variants enumerated above. (e2m3 is intentionally not instantiated here to
// keep reference-kernel compile time down — see the note above.)
void initialize_gemm_reference_operations_f6_f6_f32(Manifest &manifest) {
// 1. e3m2_e3m2_f32_f16_e4m3: f16 source tensor C, e4m3 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e3m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
// 2. e3m2_e3m2_f32_f16_e5m2: f16 source tensor C, e5m2 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e3m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
// 3. e3m2_e3m2_f32_f16_f16: f16 source tensor C, f16 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e3m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
// 4. e3m2_e3m2_f32_f32_f32: f32 source tensor C, f32 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e3m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations for FP6 (e3m2) x FP8 (e4m3) inputs with FP32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A: float_e3m2_t
// B: float_e4m3_t
// Acc: f32
// C/D : several variants
// 1. e3m2_e4m3_f32_f16_e4m3
// 2. e3m2_e4m3_f32_f16_e5m2
// 3. e3m2_e4m3_f32_f16_f16
// 4. e3m2_e4m3_f32_f32_f32
// Registers reference GEMM operations with A = float_e3m2_t (FP6) and
// B = float_e4m3_t (FP8), FP32 accumulation, in canonical layouts, covering
// the four C/D type variants enumerated above.
void initialize_gemm_reference_operations_f6_f8_f32(Manifest &manifest) {
// 1. e3m2_e4m3_f32_f16_e4m3: f16 source tensor C, e4m3 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
// 2. e3m2_e4m3_f32_f16_e5m2: f16 source tensor C, e5m2 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
// 3. e3m2_e4m3_f32_f16_f16: f16 source tensor C, f16 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e4m3_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
// 4. e3m2_e4m3_f32_f32_f32: f32 source tensor C, f32 output D
make_gemm_real_canonical_layouts<
float_e3m2_t, // ElementA
float_e4m3_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations for FP8 (e4m3) x FP4 (e2m1) inputs with FP32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A: float_e4m3_t
// B: float_e2m1_t
// Acc: f32
// C/D : several variants
// 1. e4m3_e2m1_f32_f16_e4m3
// 2. e4m3_e2m1_f32_f16_e5m2
// 3. e4m3_e2m1_f32_f16_f16
// 4. e4m3_e2m1_f32_f32_f32
// Registers reference GEMM operations with A = float_e4m3_t (FP8) and
// B = float_e2m1_t (FP4), FP32 accumulation, in canonical layouts, covering
// the four C/D type variants enumerated above.
void initialize_gemm_reference_operations_f8_f4_f32(Manifest &manifest) {
// 1. e4m3_e2m1_f32_f16_e4m3: f16 source tensor C, e4m3 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e2m1_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
// 2. e4m3_e2m1_f32_f16_e5m2: f16 source tensor C, e5m2 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e2m1_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
// 3. e4m3_e2m1_f32_f16_f16: f16 source tensor C, f16 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e2m1_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
// 4. e4m3_e2m1_f32_f32_f32: f32 source tensor C, f32 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e2m1_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations for FP8 (e4m3) x FP6 (e3m2) inputs with FP32 accumulation.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// A: float_e4m3_t
// B: float_e3m2_t
// Acc: f32
// C/D : several variants
// 1. e4m3_e3m2_f32_f16_e4m3
// 2. e4m3_e3m2_f32_f16_e5m2
// 3. e4m3_e3m2_f32_f16_f16
// 4. e4m3_e3m2_f32_f32_f32
// Registers reference GEMM operations with A = float_e4m3_t (FP8) and
// B = float_e3m2_t (FP6), FP32 accumulation, in canonical layouts, covering
// the four C/D type variants enumerated above.
void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest) {
// 1. e4m3_e3m2_f32_f16_e4m3: f16 source tensor C, e4m3 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e3m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e4m3_t // ElementD
>(manifest);
// 2. e4m3_e3m2_f32_f16_e5m2: f16 source tensor C, e5m2 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e3m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float_e5m2_t // ElementD
>(manifest);
// 3. e4m3_e3m2_f32_f16_f16: f16 source tensor C, f16 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e3m2_t, // ElementB
half_t, // ElementC
float, // ElementScalar
float, // ElementAccumulator
half_t // ElementD
>(manifest);
// 4. e4m3_e3m2_f32_f32_f32: f32 source tensor C, f32 output D
make_gemm_real_canonical_layouts<
float_e4m3_t, // ElementA
float_e3m2_t, // ElementB
float, // ElementC
float, // ElementScalar
float, // ElementAccumulator
float // ElementD
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -87,6 +87,17 @@ void initialize_gemm_reference_operations_u8_u8_s32(Manifest &manifest) {
NumericConverterClamp<int8_t, float> // From Scalar to D
>(manifest);
// 4.
make_gemm_real_canonical_layouts<
uint8_t, // ElementA
uint8_t, // ElementB
int8_t, // ElementC
float, // ElementScalar / ElementCompute
int32_t, // ElementAccumulator
uint8_t, // ElementD
NumericConverterClamp<uint8_t, float> // From Scalar to D
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -52,6 +52,19 @@ void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest);
void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest);
void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest);
void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest);
void initialize_gemm_reference_operations_f4_f4_f32(Manifest &manifest);
void initialize_gemm_reference_operations_f4_f6_f32(Manifest &manifest);
void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest);
void initialize_gemm_reference_operations_f6_f4_f32(Manifest &manifest);
void initialize_gemm_reference_operations_f6_f6_f32(Manifest &manifest);
void initialize_gemm_reference_operations_f6_f8_f32(Manifest &manifest);
void initialize_gemm_reference_operations_f8_f4_f32(Manifest &manifest);
void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest);
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
@ -89,6 +102,19 @@ void initialize_reference_operations(Manifest &manifest) {
initialize_gemm_reference_operations_fp_mixed_input(manifest);
initialize_gemm_reference_operations_int_mixed_input(manifest);
initialize_gemm_reference_operations_f4_f4_f32(manifest);
initialize_gemm_reference_operations_f4_f6_f32(manifest);
initialize_gemm_reference_operations_f4_f8_f32(manifest);
initialize_gemm_reference_operations_f6_f4_f32(manifest);
initialize_gemm_reference_operations_f6_f6_f32(manifest);
initialize_gemm_reference_operations_f6_f8_f32(manifest);
initialize_gemm_reference_operations_f8_f4_f32(manifest);
initialize_gemm_reference_operations_f8_f6_f32(manifest);
initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest);
initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest);
initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -211,6 +211,10 @@ protected:
}
}
if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
operator_args.scheduler.splits = arguments->split_k_slices;
}
return status;
}

View File

@ -334,6 +334,7 @@ static struct {
OperationKind_enumerants[] = {
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
{"gemm", "Gemm", OperationKind::kGemm},
{"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm},
{"rank_k", "RankK", OperationKind::kRankK},
{"rank_2k", "Rank2K", OperationKind::kRank2K},
{"trmm", "Trmm", OperationKind::kTrmm},
@ -422,6 +423,53 @@ Status from_string<Status>(std::string const &str) {
///////////////////////////////////////////////////////////////////////////////////////////////////
// Lookup table pairing each RuntimeDatatype enumerant with its plain-text
// spelling and its "pretty" (angle-bracketed) spelling. Consumed by the
// to_string / from_string conversions for RuntimeDatatype.
static struct {
char const *text;
char const *pretty;
RuntimeDatatype enumerant;
}
RuntimeDatatype_enumerants[] = {
{"e4m3", "<e4m3>", RuntimeDatatype::kE4M3},
{"e5m2", "<e5m2>", RuntimeDatatype::kE5M2},
{"e3m2", "<e3m2>", RuntimeDatatype::kE3M2},
{"e2m3", "<e2m3>", RuntimeDatatype::kE2M3},
{"e2m1", "<e2m1>", RuntimeDatatype::kE2M1}
};
/// Converts a RuntimeDatatype enumerant to its string spelling.
/// Returns the angle-bracketed "pretty" form when `pretty` is true; yields an
/// "invalid" marker for any enumerant missing from the lookup table.
char const *to_string(RuntimeDatatype type, bool pretty) {
  for (auto const &entry : RuntimeDatatype_enumerants) {
    if (entry.enumerant == type) {
      return pretty ? entry.pretty : entry.text;
    }
  }
  return pretty ? "Invalid" : "invalid";
}
/// Parses a RuntimeDatatype enumerant from a string.
/// Accepts either the plain or the "pretty" spelling; returns
/// RuntimeDatatype::kInvalid when no table entry matches.
template <>
RuntimeDatatype from_string<RuntimeDatatype>(std::string const &str) {
  for (auto const &entry : RuntimeDatatype_enumerants) {
    if (str == entry.text || str == entry.pretty) {
      return entry.enumerant;
    }
  }
  return RuntimeDatatype::kInvalid;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
static struct {
@ -447,6 +495,16 @@ NumericTypeID_enumerants[] = {
{"s64", "S64", NumericTypeID::kS64},
{"fe4m3", "FE4M3", NumericTypeID::kFE4M3},
{"fe5m2", "FE5M2", NumericTypeID::kFE5M2},
{"f8", "F8", NumericTypeID::kF8},
{"f6", "F6", NumericTypeID::kF6},
{"f4", "F4", NumericTypeID::kF4},
{"fe2m3", "FE2M3", NumericTypeID::kFE2M3},
{"fe3m2", "FE3M2", NumericTypeID::kFE3M2},
{"fe2m1", "FE2M1", NumericTypeID::kFE2M1},
{"fue8m0", "FUE8M0", NumericTypeID::kFUE8M0},
{"fue4m3", "FUE4M3", NumericTypeID::kFUE4M3},
{"f16", "F16", NumericTypeID::kF16},
{"bf16", "BF16", NumericTypeID::kBF16},
{"f32", "F32", NumericTypeID::kF32},
@ -510,6 +568,16 @@ int sizeof_bits(NumericTypeID type) {
switch (type) {
case NumericTypeID::kFE4M3: return 8;
case NumericTypeID::kFE5M2: return 8;
case NumericTypeID::kF8: return 8;
case NumericTypeID::kF6: return 6;
case NumericTypeID::kF4: return 4;
case NumericTypeID::kFE2M3: return 6;
case NumericTypeID::kFE3M2: return 6;
case NumericTypeID::kFE2M1: return 4;
case NumericTypeID::kFUE8M0: return 8;
case NumericTypeID::kFUE4M3: return 8;
case NumericTypeID::kF16: return 16;
case NumericTypeID::kBF16: return 16;
case NumericTypeID::kTF32: return 32;
@ -589,6 +657,16 @@ bool is_signed_type(NumericTypeID type) {
switch (type) {
case NumericTypeID::kFE4M3: return true;
case NumericTypeID::kFE5M2: return true;
case NumericTypeID::kF8: return true;
case NumericTypeID::kF6: return true;
case NumericTypeID::kF4: return true;
case NumericTypeID::kFE2M3: return true;
case NumericTypeID::kFE3M2: return true;
case NumericTypeID::kFE2M1: return true;
case NumericTypeID::kFUE8M0: return false;
case NumericTypeID::kFUE4M3: return false;
case NumericTypeID::kF16: return true;
case NumericTypeID::kBF16: return true;
case NumericTypeID::kTF32: return true;
@ -620,6 +698,16 @@ bool is_float_type(NumericTypeID type) {
switch (type) {
case NumericTypeID::kFE4M3: return true;
case NumericTypeID::kFE5M2: return true;
case NumericTypeID::kF8: return true;
case NumericTypeID::kF6: return true;
case NumericTypeID::kF4: return true;
case NumericTypeID::kFE2M3: return true;
case NumericTypeID::kFE3M2: return true;
case NumericTypeID::kFE2M1: return true;
case NumericTypeID::kFUE8M0: return true;
case NumericTypeID::kFUE4M3: return true;
case NumericTypeID::kF16: return true;
case NumericTypeID::kBF16: return true;
case NumericTypeID::kTF32: return true;
@ -1168,6 +1256,43 @@ bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string c
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(tmp);
}
break;
case NumericTypeID::kFE2M3:
{
float tmp;
ss >> tmp;
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(tmp);
}
break;
case NumericTypeID::kFE3M2:
{
float tmp;
ss >> tmp;
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(tmp);
}
break;
case NumericTypeID::kFE2M1:
{
float tmp;
ss >> tmp;
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(tmp);
}
break;
case NumericTypeID::kFUE8M0:
{
float tmp;
ss >> tmp;
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(tmp);
}
break;
case NumericTypeID::kFUE4M3:
{
float tmp;
ss >> tmp;
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(tmp);
}
break;
case NumericTypeID::kF16:
{
float tmp;
@ -1317,6 +1442,38 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
ss << tmp;
}
break;
case NumericTypeID::kFE2M3:
{
float tmp = *reinterpret_cast<float_e2m3_t *>(bytes.data());
ss << tmp;
}
break;
case NumericTypeID::kFE3M2:
{
float tmp = *reinterpret_cast<float_e3m2_t *>(bytes.data());
ss << tmp;
}
break;
case NumericTypeID::kFE2M1:
{
float tmp = *reinterpret_cast<float_e2m1_t *>(bytes.data());
ss << tmp;
}
break;
case NumericTypeID::kFUE8M0:
{
float tmp = *reinterpret_cast<float_ue8m0_t *>(bytes.data());
ss << tmp;
}
break;
case NumericTypeID::kFUE4M3:
{
float tmp = *reinterpret_cast<float_ue4m3_t *>(bytes.data());
ss << tmp;
}
break;
case NumericTypeID::kF16:
{
float tmp = *reinterpret_cast<half_t *>(bytes.data());
@ -1469,6 +1626,33 @@ bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t sr
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src));
}
break;
case NumericTypeID::kFE2M3:
{
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(float(src));
}
break;
case NumericTypeID::kFE3M2:
{
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(float(src));
}
break;
case NumericTypeID::kFE2M1:
{
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
}
break;
case NumericTypeID::kFUE8M0:
{
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
}
break;
case NumericTypeID::kFUE4M3:
{
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
}
break;
case NumericTypeID::kF16:
{
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
@ -1579,6 +1763,33 @@ bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src));
}
break;
case NumericTypeID::kFE2M3:
{
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(float(src));
}
break;
case NumericTypeID::kFE3M2:
{
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(float(src));
}
break;
case NumericTypeID::kFE2M1:
{
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
}
break;
case NumericTypeID::kFUE8M0:
{
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
}
break;
case NumericTypeID::kFUE4M3:
{
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
}
break;
case NumericTypeID::kF16:
{
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
@ -1690,6 +1901,33 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src));
}
break;
case NumericTypeID::kFE2M3:
{
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(float(src));
}
break;
case NumericTypeID::kFE3M2:
{
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(float(src));
}
break;
case NumericTypeID::kFE2M1:
{
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
}
break;
case NumericTypeID::kFUE8M0:
{
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
}
break;
case NumericTypeID::kFUE4M3:
{
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
}
break;
case NumericTypeID::kF16:
{
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
@ -1751,6 +1989,35 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
return true;
}
/// Maps a runtime-selectable datatype tag (RuntimeDatatype) to the
/// corresponding static NumericTypeID used by the library's descriptions.
///
/// For an unhandled enumerant this asserts in debug builds and returns a
/// value-initialized NumericTypeID in release builds.
NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type) {
  NumericTypeID element{};
  switch (type) {
    case RuntimeDatatype::kE4M3:
      element = NumericTypeID::kFE4M3;
      break;
    case RuntimeDatatype::kE5M2:
      element = NumericTypeID::kFE5M2;
      break;
    case RuntimeDatatype::kE2M3:
      element = NumericTypeID::kFE2M3;
      break;
    case RuntimeDatatype::kE3M2:
      element = NumericTypeID::kFE3M2;
      break;
    case RuntimeDatatype::kE2M1:
      element = NumericTypeID::kFE2M1;
      break;
    default:
      // Fix: assert("...") tests a string literal, which is always non-null
      // and therefore never fires. assert(false && "...") actually triggers
      // on an unexpected enumerant while keeping the message visible.
      assert(false && "illegal runtime datatype!");
      break;
  }
  return element;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library

View File

@ -46,6 +46,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
src/problem_space.cpp
src/operation_profiler.cu
src/gemm_operation_profiler.cu
src/block_scaled_gemm_operation_profiler.cu
src/rank_k_operation_profiler.cu
src/rank_2k_operation_profiler.cu
src/trmm_operation_profiler.cu
@ -101,6 +102,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 A
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
else()
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
endif()
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true)

View File

@ -0,0 +1,290 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Blockscale Gemm Profiler
*/
#pragma once
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <unordered_map>
// CUTLASS Library includes
#include "cutlass/library/library.h"
#include "cutlass/library/util.h"
#include "cutlass/library/manifest.h"
// Profiler includes
#include "options.h"
#include "device_context.h"
#include "operation_profiler.h"
#include "performance_result.h"
#include "problem_space.h"
#include "reduction_operation_profiler.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace profiler {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Abstract base class for each math function
class BlockScaledGemmOperationProfiler : public OperationProfiler {
public:
/// Block-scaled GEMM problem configuration obtained from the profiler's
/// problem space: extents, cluster shapes, strides, epilogue scalars, and
/// split-K / scheduling options.
struct GemmProblem {
// GEMM universal mode (defaults to the plain GEMM mode).
cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm};
// Problem extents M, N, K.
int64_t m{16};
int64_t n{16};
int64_t k{16};
// Threadblock cluster shape.
int cluster_m{1};
int cluster_n{1};
int cluster_k{1};
// Fallback cluster shape — presumably used when the preferred cluster
// shape cannot be scheduled; TODO confirm against the runner.
int cluster_m_fallback{1};
int cluster_n_fallback{1};
int cluster_k_fallback{1};
// Leading dimensions of A, B, and C (0 here; filled in during parse).
int64_t lda{0};
int64_t ldb{0};
int64_t ldc{0};
// Epilogue scalars, type-erased as raw bytes of the operation's scalar type.
std::vector<uint8_t> alpha;
std::vector<uint8_t> beta;
cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone};
int split_k_slices{1};
int batch_count{1};
cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
int swizzle_size{1};
// Runtime-selected input datatypes for operands A and B.
cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
// gemm with parallel interleaved reduction:
// gemm epilogue (alpha, beta) = (1.0, 0.0)
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
std::vector<uint8_t> alpha_one;
std::vector<uint8_t> beta_zero;
// NOTE(review): presumably enables programmatic dependent launch (PDL)
// for the profiled kernel — confirm.
bool use_pdl{false};
//
// Methods
//
/// Parses the problem from the problem space; returns a Status error code.
Status parse(
library::BlockScaledGemmDescription const &operation_desc,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Total number of bytes loaded for this problem.
int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const;
/// Total number of floating-point operations computed for this problem.
int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const;
/// Initializes a performance result with this problem's attributes.
void initialize_result(
PerformanceResult &result,
library::BlockScaledGemmDescription const &operation_desc,
ProblemSpace const &problem_space);
};
/// Workspace used
struct GemmWorkspace {
DeviceAllocation *A{nullptr};
DeviceAllocation *SFA{nullptr};
DeviceAllocation *B{nullptr};
DeviceAllocation *SFB{nullptr};
DeviceAllocation *C{nullptr};
DeviceAllocation *Computed{nullptr};
DeviceAllocation *Reference{nullptr};
DeviceAllocation *Computed_SFD{nullptr};
DeviceAllocation *Reference_SFD{nullptr};
DeviceAllocation *Norm_constant{nullptr};
/// Number of copies of the problem workspace which are visited sequentially during
/// profiling to avoid camping in the last level cache.
int problem_count{1};
library::GemmUniversalConfiguration configuration;
library::BlockScaledGemmArguments arguments;
/// Buffer used for the operation's host workspace
std::vector<uint8_t> host_workspace;
/// Buffer used for the operations' device workspace
DeviceAllocation device_workspace;
/// Library configuration and arguments for reduction operator
library::ReductionConfiguration reduction_configuration;
library::ReductionArguments reduction_arguments;
/// Buffer used for the cutlass reduction operations' host workspace
std::vector<uint8_t> reduction_host_workspace;
};
protected:
//
// Data members
//
/// GEMM problem obtained from problem space
GemmProblem problem_;
/// Device memory allocations
GemmWorkspace gemm_workspace_;
/// CUTLASS parallel reduction operation to follow this* gemm operation
library::Operation const *reduction_op_;
public:
//
// Methods
//
/// Ctor
BlockScaledGemmOperationProfiler(Options const &options);
/// Destructor
virtual ~BlockScaledGemmOperationProfiler();
GemmProblem const& problem() const { return problem_; }
/// Prints usage statement for the math function
virtual void print_usage(std::ostream &out) const;
/// Prints examples
virtual void print_examples(std::ostream &out) const;
/// Extracts the problem dimensions
virtual Status initialize_configuration(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Initializes workspace
virtual Status initialize_workspace(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Verifies CUTLASS against references
virtual bool verify_cutlass(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Measures performance results
virtual bool profile(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
protected:
/// Initializes the performance result
void initialize_result_(
PerformanceResult &result,
Options const &options,
library::BlockScaledGemmDescription const &operation_desc,
ProblemSpace const &problem_space);
/// Verifies CUTLASS against references
bool verify_with_cublas_(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Verifies CUTLASS against host and device references
bool verify_with_reference_(
Options const &options,
PerformanceReport &report,
DeviceContext &device_context,
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem,
cutlass::library::NumericTypeID element_A,
cutlass::library::NumericTypeID element_B);
/// Method to profile a CUTLASS Operation
Status profile_cutlass_(
PerformanceResult &result,
Options const &options,
library::Operation const *operation,
void *arguments,
void *host_workspace,
void *device_workspace);
/// Initialize reduction problem dimensions and library::Operation
bool initialize_reduction_configuration_(
library::Operation const *operation,
ProblemSpace::Problem const &problem);
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -73,6 +73,15 @@ public:
int64_t n{16};
int64_t k{16};
int cluster_m{1};
int cluster_n{1};
int cluster_k{1};
int cluster_m_fallback{1};
int cluster_n_fallback{1};
int cluster_k_fallback{1};
int64_t lda{0};
int64_t ldb{0};
int64_t ldc{0};
@ -86,6 +95,11 @@ public:
cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
int swizzle_size{1};
cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
// gemm with parallel interleaved reduction
// gemm epilogue (alpha, beta) = (1.0, 0.0)
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)

View File

@ -942,6 +942,18 @@ bool arg_as_IteratorAlgorithmID(
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr);
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
bool arg_as_RuntimeDatatype(
library::RuntimeDatatype &runtime_datatype,
char const *name,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem);
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr);

File diff suppressed because it is too large Load Diff

View File

@ -38,6 +38,7 @@
// Profiler includes
#include "cutlass/profiler/cutlass_profiler.h"
#include "cutlass/profiler/gemm_operation_profiler.h"
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
#include "cutlass/profiler/rank_k_operation_profiler.h"
#include "cutlass/profiler/rank_2k_operation_profiler.h"
#include "cutlass/profiler/trmm_operation_profiler.h"
@ -60,6 +61,8 @@ CutlassProfiler::CutlassProfiler(
operation_profilers_.emplace_back(new GemmOperationProfiler(options));
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options));
operation_profilers_.emplace_back(new Conv2dOperationProfiler(options));

View File

@ -616,6 +616,48 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kFUE4M3:
cutlass::reference::device::BlockFillRandom<cutlass::float_ue4m3_t>(
reinterpret_cast<cutlass::float_ue4m3_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFUE8M0:
cutlass::reference::device::BlockFillRandom<cutlass::float_ue8m0_t>(
reinterpret_cast<cutlass::float_ue8m0_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::device::BlockFillRandom<cutlass::float_e2m3_t>(
reinterpret_cast<cutlass::float_e2m3_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFE3M2:
cutlass::reference::device::BlockFillRandom<cutlass::float_e3m2_t>(
reinterpret_cast<cutlass::float_e3m2_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFE2M1:
cutlass::reference::device::BlockFillRandom<cutlass::float_e2m1_t>(
reinterpret_cast<cutlass::float_e2m1_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kF64:
cutlass::reference::device::BlockFillRandom<double>(
reinterpret_cast<double *>(pointer_),
@ -771,6 +813,50 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
dist
);
break;
case library::NumericTypeID::kFUE4M3:
cutlass::reference::host::BlockFillRandom<cutlass::float_ue4m3_t>(
reinterpret_cast<cutlass::float_ue4m3_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::host::BlockFillRandom<cutlass::float_e2m3_t>(
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFE3M2:
cutlass::reference::host::BlockFillRandom<cutlass::float_e3m2_t>(
reinterpret_cast<cutlass::float_e3m2_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFE2M1:
cutlass::reference::host::BlockFillRandom<cutlass::float_e2m1_t>(
reinterpret_cast<cutlass::float_e2m1_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kFUE8M0:
cutlass::reference::host::BlockFillRandom<cutlass::float_ue8m0_t>(
reinterpret_cast<cutlass::float_ue8m0_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kF16:
cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(host_data.data()),
@ -990,6 +1076,50 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) {
static_cast<cutlass::float_e5m2_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFUE4M3:
cutlass::reference::device::BlockFillSequential<cutlass::float_ue4m3_t>(
reinterpret_cast<cutlass::float_ue4m3_t *>(pointer_),
capacity_,
static_cast<cutlass::float_ue4m3_t>(dist.sequential.delta),
static_cast<cutlass::float_ue4m3_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::device::BlockFillSequential<cutlass::float_e2m3_t>(
reinterpret_cast<cutlass::float_e2m3_t *>(pointer_),
capacity_,
static_cast<cutlass::float_e2m3_t>(dist.sequential.delta),
static_cast<cutlass::float_e2m3_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFE3M2:
cutlass::reference::device::BlockFillSequential<cutlass::float_e3m2_t>(
reinterpret_cast<cutlass::float_e3m2_t *>(pointer_),
capacity_,
static_cast<cutlass::float_e3m2_t>(dist.sequential.delta),
static_cast<cutlass::float_e3m2_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFE2M1:
cutlass::reference::device::BlockFillSequential<cutlass::float_e2m1_t>(
reinterpret_cast<cutlass::float_e2m1_t *>(pointer_),
capacity_,
static_cast<cutlass::float_e2m1_t>(dist.sequential.delta),
static_cast<cutlass::float_e2m1_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFUE8M0:
cutlass::reference::device::BlockFillSequential<cutlass::float_ue8m0_t>(
reinterpret_cast<cutlass::float_ue8m0_t *>(pointer_),
capacity_,
static_cast<cutlass::float_ue8m0_t>(dist.sequential.delta),
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kF16:
cutlass::reference::device::BlockFillSequential<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(pointer_),
@ -1220,6 +1350,50 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) {
static_cast<cutlass::float_e5m2_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFUE4M3:
cutlass::reference::host::BlockFillSequential<cutlass::float_ue4m3_t>(
reinterpret_cast<cutlass::float_ue4m3_t *>(host_data.data()),
capacity_,
static_cast<cutlass::float_ue4m3_t>(dist.sequential.delta),
static_cast<cutlass::float_ue4m3_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::host::BlockFillSequential<cutlass::float_e2m3_t>(
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
capacity_,
static_cast<cutlass::float_e2m3_t>(dist.sequential.delta),
static_cast<cutlass::float_e2m3_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFE3M2:
cutlass::reference::host::BlockFillSequential<cutlass::float_e3m2_t>(
reinterpret_cast<cutlass::float_e3m2_t *>(host_data.data()),
capacity_,
static_cast<cutlass::float_e3m2_t>(dist.sequential.delta),
static_cast<cutlass::float_e3m2_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFE2M1:
cutlass::reference::host::BlockFillSequential<cutlass::float_e2m1_t>(
reinterpret_cast<cutlass::float_e2m1_t *>(host_data.data()),
capacity_,
static_cast<cutlass::float_e2m1_t>(dist.sequential.delta),
static_cast<cutlass::float_e2m1_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kFUE8M0:
cutlass::reference::host::BlockFillSequential<cutlass::float_ue8m0_t>(
reinterpret_cast<cutlass::float_ue8m0_t *>(host_data.data()),
capacity_,
static_cast<cutlass::float_ue8m0_t>(dist.sequential.delta),
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
);
break;
case library::NumericTypeID::kF16:
cutlass::reference::host::BlockFillSequential<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(host_data.data()),
@ -1516,6 +1690,34 @@ bool DeviceAllocation::block_compare_equal(
reinterpret_cast<float_e5m2_t const *>(ptr_A),
reinterpret_cast<float_e5m2_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kFUE4M3:
return reference::device::BlockCompareEqual<float_ue4m3_t>(
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
reinterpret_cast<float_ue4m3_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kFUE8M0:
return reference::device::BlockCompareEqual<float_ue8m0_t>(
reinterpret_cast<float_ue8m0_t const *>(ptr_A),
reinterpret_cast<float_ue8m0_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kFE2M3:
return reference::device::BlockCompareEqual<float_e2m3_t>(
reinterpret_cast<float_e2m3_t const *>(ptr_A),
reinterpret_cast<float_e2m3_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kFE3M2:
return reference::device::BlockCompareEqual<float_e3m2_t>(
reinterpret_cast<float_e3m2_t const *>(ptr_A),
reinterpret_cast<float_e3m2_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kFE2M1:
return reference::device::BlockCompareEqual<float_e2m1_t>(
reinterpret_cast<float_e2m1_t const *>(ptr_A),
reinterpret_cast<float_e2m1_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kF16:
return reference::device::BlockCompareEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
@ -1684,6 +1886,46 @@ bool DeviceAllocation::block_compare_relatively_equal(
capacity,
static_cast<float_e5m2_t>(epsilon),
static_cast<float_e5m2_t>(nonzero_floor));
case library::NumericTypeID::kFUE4M3:
return reference::device::BlockCompareRelativelyEqual<float_ue4m3_t>(
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
reinterpret_cast<float_ue4m3_t const *>(ptr_B),
capacity,
static_cast<float_ue4m3_t>(epsilon),
static_cast<float_ue4m3_t>(nonzero_floor));
case library::NumericTypeID::kFUE8M0:
return reference::device::BlockCompareRelativelyEqual<float_ue8m0_t>(
reinterpret_cast<float_ue8m0_t const *>(ptr_A),
reinterpret_cast<float_ue8m0_t const *>(ptr_B),
capacity,
static_cast<float_ue8m0_t>(epsilon),
static_cast<float_ue8m0_t>(nonzero_floor));
case library::NumericTypeID::kFE2M3:
return reference::device::BlockCompareRelativelyEqual<float_e2m3_t>(
reinterpret_cast<float_e2m3_t const *>(ptr_A),
reinterpret_cast<float_e2m3_t const *>(ptr_B),
capacity,
static_cast<float_e2m3_t>(epsilon),
static_cast<float_e2m3_t>(nonzero_floor));
case library::NumericTypeID::kFE3M2:
return reference::device::BlockCompareRelativelyEqual<float_e3m2_t>(
reinterpret_cast<float_e3m2_t const *>(ptr_A),
reinterpret_cast<float_e3m2_t const *>(ptr_B),
capacity,
static_cast<float_e3m2_t>(epsilon),
static_cast<float_e3m2_t>(nonzero_floor));
case library::NumericTypeID::kFE2M1:
return reference::device::BlockCompareRelativelyEqual<float_e2m1_t>(
reinterpret_cast<float_e2m1_t const *>(ptr_A),
reinterpret_cast<float_e2m1_t const *>(ptr_B),
capacity,
static_cast<float_e2m1_t>(epsilon),
static_cast<float_e2m1_t>(nonzero_floor));
case library::NumericTypeID::kF16:
return reference::device::BlockCompareRelativelyEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
@ -2026,6 +2268,27 @@ void DeviceAllocation::write_tensor_csv(
case library::NumericTypeID::kFE5M2:
write_tensor_csv_static_type<float_e5m2_t>(out, *this);
break;
case library::NumericTypeID::kFUE4M3:
write_tensor_csv_static_type<float_ue4m3_t>(out, *this);
break;
case library::NumericTypeID::kFE2M3:
write_tensor_csv_static_type<float_e2m3_t>(out, *this);
break;
case library::NumericTypeID::kFE3M2:
write_tensor_csv_static_type<float_e3m2_t>(out, *this);
break;
case library::NumericTypeID::kFE2M1:
write_tensor_csv_static_type<float_e2m1_t>(out, *this);
break;
case library::NumericTypeID::kFUE8M0:
write_tensor_csv_static_type<float_ue8m0_t>(out, *this);
break;
case library::NumericTypeID::kF16:
write_tensor_csv_static_type<half_t>(out, *this);
break;
@ -2193,6 +2456,27 @@ void DeviceAllocation::fill_device(double val = 0.0) {
case library::NumericTypeID::kFE5M2:
tensor_fill<float_e5m2_t>(*this, static_cast<float_e5m2_t>(val));
break;
case library::NumericTypeID::kFUE4M3:
tensor_fill<float_ue4m3_t>(*this, static_cast<float_ue4m3_t>(val));
break;
case library::NumericTypeID::kFUE8M0:
tensor_fill<float_ue8m0_t>(*this, static_cast<float_ue8m0_t>(val));
break;
case library::NumericTypeID::kFE2M3:
tensor_fill<float_e2m3_t>(*this, static_cast<float_e2m3_t>(val));
break;
case library::NumericTypeID::kFE3M2:
tensor_fill<float_e3m2_t>(*this, static_cast<float_e3m2_t>(val));
break;
case library::NumericTypeID::kFE2M1:
tensor_fill<float_e2m1_t>(*this, static_cast<float_e2m1_t>(val));
break;
case library::NumericTypeID::kF16:
tensor_fill<half_t>(*this, static_cast<half_t>(val));
break;
@ -2288,6 +2572,47 @@ void DeviceAllocation::fill_host(double val = 0.0) {
std::vector<uint8_t> host_data(bytes());
switch (this->type()) {
case library::NumericTypeID::kFUE4M3:
cutlass::reference::host::BlockFill<float_ue4m3_t>(
reinterpret_cast<float_ue4m3_t *>(host_data.data()),
capacity_,
static_cast<float_ue4m3_t>(val)
);
break;
case library::NumericTypeID::kFUE8M0:
cutlass::reference::host::BlockFill<float_ue8m0_t>(
reinterpret_cast<float_ue8m0_t *>(host_data.data()),
capacity_,
static_cast<float_ue8m0_t>(val)
);
break;
case library::NumericTypeID::kFE2M3:
cutlass::reference::host::BlockFill<float_e2m3_t>(
reinterpret_cast<float_e2m3_t *>(host_data.data()),
capacity_,
static_cast<float_e2m3_t>(val)
);
break;
case library::NumericTypeID::kFE3M2:
cutlass::reference::host::BlockFill<float_e3m2_t>(
reinterpret_cast<float_e3m2_t *>(host_data.data()),
capacity_,
static_cast<float_e3m2_t>(val)
);
break;
case library::NumericTypeID::kFE2M1:
cutlass::reference::host::BlockFill<float_e2m1_t>(
reinterpret_cast<float_e2m1_t *>(host_data.data()),
capacity_,
static_cast<float_e2m1_t>(val)
);
break;
case library::NumericTypeID::kFE4M3:
cutlass::reference::host::BlockFill<float_e4m3_t>(
reinterpret_cast<float_e4m3_t *>(host_data.data()),

View File

@ -104,6 +104,25 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
case library::NumericTypeID::kFE5M2:
data_distribution.set_uniform(-1, 1, 0);
break;
case library::NumericTypeID::kFE2M3:
data_distribution.set_uniform(-2, 2, 0);
break;
case library::NumericTypeID::kFE3M2:
data_distribution.set_uniform(-2, 2, 0);
break;
case library::NumericTypeID::kFE2M1:
data_distribution.set_uniform(-2, 2, 0);
break;
case library::NumericTypeID::kFUE8M0:
data_distribution.set_uniform(1, 4, 0);
break;
case library::NumericTypeID::kFUE4M3:
data_distribution.set_uniform(1, 4, 0);
break;
case library::NumericTypeID::kF16:
data_distribution.set_uniform(-3, 3, 0);
break;

View File

@ -76,6 +76,8 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"},
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
{ArgumentTypeID::kInteger, {"use_pdl", "use-pdl"}, "Use PDL (true, false)"},
{ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"},
},
@ -172,6 +174,38 @@ Status GemmOperationProfiler::GemmProblem::parse(
this->k = 1024;
}
if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) {
// default value
this->cluster_m = 1;
}
if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) {
// default value
this->cluster_n = 1;
}
if (!arg_as_int(this->cluster_k, "cluster_k", problem_space, problem)) {
// default value
this->cluster_k = 1;
}
if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) {
// default value
this->cluster_m_fallback = 0;
}
if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) {
// default value
this->cluster_n_fallback = 0;
}
if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) {
// default value
this->cluster_k_fallback = 0;
}
if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) {
// default value
this->use_pdl = false;
@ -192,6 +226,18 @@ Status GemmOperationProfiler::GemmProblem::parse(
this->split_k_slices = 1;
}
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) {
// default value
this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic;
}
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) {
// default value
this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic;
}
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
// default value
this->batch_count = 1;
@ -338,6 +384,15 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "n", problem_space, n);
set_argument(result, "k", problem_space, k);
set_argument(result, "cluster_m", problem_space, cluster_m);
set_argument(result, "cluster_n", problem_space, cluster_n);
set_argument(result, "cluster_k", problem_space, cluster_k);
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
set_argument(result, "split_k_slices", problem_space, split_k_slices);
set_argument(result, "batch_count", problem_space, batch_count);
@ -345,6 +400,11 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "swizzle_size", problem_space, swizzle_size);
set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl));
set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a));
set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b));
set_argument(result, "alpha", problem_space,
library::lexical_cast(alpha, operation_desc.element_epilogue));
@ -388,6 +448,14 @@ Status GemmOperationProfiler::initialize_configuration(
gemm_workspace_[i].configuration.problem_size.m() = int(problem_.m);
gemm_workspace_[i].configuration.problem_size.n() = int(problem_.n);
gemm_workspace_[i].configuration.problem_size.k() = int(problem_.k);
gemm_workspace_[i].configuration.cluster_shape.m() = int(problem_.cluster_m);
gemm_workspace_[i].configuration.cluster_shape.n() = int(problem_.cluster_n);
gemm_workspace_[i].configuration.cluster_shape.k() = int(problem_.cluster_k);
gemm_workspace_[i].configuration.cluster_shape_fallback.m() = int(problem_.cluster_m_fallback);
gemm_workspace_[i].configuration.cluster_shape_fallback.n() = int(problem_.cluster_n_fallback);
gemm_workspace_[i].configuration.cluster_shape_fallback.k() = int(problem_.cluster_k_fallback);
gemm_workspace_[i].configuration.lda = problem_.lda;
gemm_workspace_[i].configuration.ldb = problem_.ldb;
gemm_workspace_[i].configuration.ldc = problem_.ldc;
@ -423,6 +491,15 @@ Status GemmOperationProfiler::initialize_configuration(
gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost;
gemm_workspace_[i].arguments.swizzle_size = problem_.swizzle_size;
gemm_workspace_[i].arguments.raster_order = problem_.raster_order;
gemm_workspace_[i].arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)};
gemm_workspace_[i].arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)};
gemm_workspace_[i].arguments.split_k_slices = problem_.split_k_slices;
gemm_workspace_[i].arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a;
gemm_workspace_[i].arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b;
initialize_result_(this->model_result_, options, operation_desc, problem_space);
if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) {
return can_implement;
@ -621,6 +698,9 @@ Status GemmOperationProfiler::initialize_workspace(
if (options.execution_mode != ExecutionMode::kDryRun) {
// NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels
gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)};
gemm_workspace_[i].arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)};
gemm_workspace_[i].arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)};
gemm_workspace_[i].arguments.split_k_slices = problem_.split_k_slices;
gemm_workspace_[i].arguments.batch_count = problem_.batch_count;
gemm_workspace_[i].arguments.lda = problem_.lda;
gemm_workspace_[i].arguments.ldb = problem_.ldb;
@ -857,12 +937,32 @@ bool GemmOperationProfiler::verify_cutlass(
}
#endif // #if CUTLASS_ENABLE_CUBLAS
cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.front().arguments.runtime_input_datatype_a;
cutlass::library::RuntimeDatatype runtime_datatype_b = gemm_workspace_.front().arguments.runtime_input_datatype_b;
bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic;
bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic;
assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatype should be both dynamic or static.");
library::GemmDescription const &gemm_desc =
static_cast<library::GemmDescription const &>(operation->description());
cutlass::library::NumericTypeID element_A = gemm_desc.A.element;
cutlass::library::NumericTypeID element_B = gemm_desc.B.element;
if (is_runtime_datatype_a) {
element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a);
}
if (is_runtime_datatype_b) {
element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b);
}
bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem, element_A, element_B);
// Update disposition to worst case verification outcome among all
@ -1087,6 +1187,14 @@ bool GemmOperationProfiler::verify_with_reference_(
gemm_workspace_[i].configuration.problem_size.m(),
gemm_workspace_[i].configuration.problem_size.n(),
gemm_workspace_[i].configuration.problem_size.k(),
gemm_workspace_[i].configuration.cluster_shape.m(),
gemm_workspace_[i].configuration.cluster_shape.n(),
gemm_workspace_[i].configuration.cluster_shape.k(),
gemm_workspace_[i].configuration.cluster_shape_fallback.m(),
gemm_workspace_[i].configuration.cluster_shape_fallback.n(),
gemm_workspace_[i].configuration.cluster_shape_fallback.k(),
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,

View File

@ -91,6 +91,11 @@ OperationProfiler::OperationProfiler(
{ArgumentTypeID::kInteger, {"cluster_m", "cluster-shape::m"}, "Cluster shape in the M dimension"},
{ArgumentTypeID::kInteger, {"cluster_n", "cluster-shape::n"}, "Cluster shape in the N dimension"},
{ArgumentTypeID::kInteger, {"cluster_k", "cluster-shape::k"}, "Cluster shape in the K dimension"},
{ArgumentTypeID::kInteger, {"cluster_m_fallback", "cluster-shape-fallback::m"}, "Fallback Cluster shape in the M dimension"},
{ArgumentTypeID::kInteger, {"cluster_n_fallback", "cluster-shape-fallback::n"}, "Fallback Cluster shape in the N dimension"},
{ArgumentTypeID::kInteger, {"cluster_k_fallback", "cluster-shape-fallback::k"}, "Fallback Cluster shape in the K dimension"},
{ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply"},
{ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension"},
{ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension"},
@ -174,6 +179,11 @@ bool OperationProfiler::satisfies(
return false;
}
}
bool dynamic_cluster = int64_t(op_desc.tile_description.cluster_shape.m()) == 0 ||
int64_t(op_desc.tile_description.cluster_shape.n()) == 0 ||
int64_t(op_desc.tile_description.cluster_shape.k()) == 0;
int64_t int_value;
if (arg_as_int(int_value, "inst_m", problem_space, problem)) {
@ -212,6 +222,7 @@ bool OperationProfiler::satisfies(
}
}
if (!dynamic_cluster) {
if (arg_as_int(int_value, "cluster_m", problem_space, problem)) {
if (int64_t(op_desc.tile_description.cluster_shape.m()) != int_value) {
return false;
@ -230,6 +241,7 @@ bool OperationProfiler::satisfies(
}
}
}
if (arg_as_int(int_value, "stages", problem_space, problem)) {
if (int64_t(op_desc.tile_description.threadblock_stages) != int_value) {
return false;
@ -296,6 +308,11 @@ std::ostream& operator<<(std::ostream& out, library::OperationKind provider) {
if (provider == library::OperationKind::kGemm) {
out << "kGemm";
}
else if (provider == library::OperationKind::kBlockScaledGemm) {
out << "kBlockScaledGemm";
}
else if (provider == library::OperationKind::kRankK) {
out << "kRankK";
}

View File

@ -33,6 +33,7 @@
*/
#include <algorithm>
#include <fstream>
#include <set>
#include "cutlass/cutlass.h"
@ -810,16 +811,27 @@ Options::Options(cutlass::CommandLine const &cmdline):
}
else if (cmdline.check_cmd_line_flag("kernels")) {
cmdline.get_cmd_line_arguments("kernels", operation_names);
profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match");
profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled");
}
if (cmdline.check_cmd_line_flag("kernels-file")) {
std::string filename;
cmdline.get_cmd_line_argument("kernels-file", filename, {});
std::ifstream input(filename);
if (!input.good()) {
throw std::runtime_error("failed to open: " + filename);
}
for (std::string line; getline(input, line);) {
operation_names.push_back(line);
}
}
if (cmdline.check_cmd_line_flag("ignore-kernels")) {
cmdline.get_cmd_line_arguments("ignore-kernels", excluded_operation_names);
profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match");
profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled");
}
profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match");
profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled");
// Prevent launches on the device for anything other than CUTLASS operation
// Allow verification only on host
if (execution_mode == ExecutionMode::kTrace) {
@ -856,6 +868,11 @@ void Options::print_usage(std::ostream &out) const {
<< " (\"s1688\" and \"nt\") or (\"s844\" and \"tn\" and \"align8\") in their" << end_of_line
<< " operation name using --kernels=\"s1688*nt, s884*tn*align8\"\n\n"
<< " --kernels-file=<filename> "
<< " Same behavior as --kernels, but kernel names are specified in a file" << end_of_line
<< " with one kernel on each line. Set of profiled kernels is the union of kernels specified" << end_of_line
<< " here and those specified in `kernels`.\n\n"
<< " --ignore-kernels=<string_list> "
<< " Excludes kernels whose names match anything in this list.\n\n"
;

View File

@ -879,6 +879,32 @@ bool arg_as_NumericTypeID(
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
bool arg_as_RuntimeDatatype(
library::RuntimeDatatype &runtime_datatype,
KernelArgument::Value const *value_ptr) {
if (value_ptr->not_null) {
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) {
runtime_datatype = library::from_string<library::RuntimeDatatype>(
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element);
if (runtime_datatype == library::RuntimeDatatype::kInvalid) {
throw std::runtime_error(
"arg_as_RuntimeDatatype() - illegal cast.");
}
}
else {
throw std::runtime_error(
"arg_as_RuntimeDatatype() - illegal cast.");
}
return true;
}
return false;
}
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
bool arg_as_RasterOrder(
library::RasterOrder &raster_order,
@ -945,6 +971,21 @@ bool arg_as_LayoutTypeID(
return false;
}
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
/// Looks up the named argument in the problem space and, if it is defined,
/// lexically casts it to a library::RuntimeDatatype.
/// Returns true if the argument was present (not null), false otherwise.
bool arg_as_RuntimeDatatype(
  library::RuntimeDatatype &runtime_datatype,
  char const *name,
  ProblemSpace const &problem_space,
  ProblemSpace::Problem const &problem) {

  // Resolve the argument slot by name, then delegate to the value-pointer overload.
  size_t argument_idx = problem_space.argument_index(name);
  KernelArgument::Value const *value = problem.at(argument_idx).get();

  return arg_as_RuntimeDatatype(runtime_datatype, value);
}
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
bool arg_as_LayoutTypeID(
library::LayoutTypeID &layout_type,

View File

@ -922,7 +922,7 @@ __global__ void Conv3dWgrad(
filter_s = problem_size.S - 1 - filter_s;
}
int d = Z * problem_size.stride_d - problem_size.pad_w + filter_t * problem_size.dilation_d;
int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

View File

@ -352,7 +352,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, xor_add<ComputeType>>(
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
}
@ -367,7 +367,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, xor_add<ComputeType>>(
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
}
};
@ -389,7 +389,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, and_add<ComputeType>>(
ScalarType, ComputeType, and_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
}
@ -404,7 +404,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, and_add<ComputeType>>(
ScalarType, ComputeType, and_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
}
};

View File

@ -42,6 +42,7 @@
#include "cutlass/relatively_equal.h"
#include "cute/tensor.hpp"
#include "cute/pointer.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -59,10 +60,20 @@ struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T
/////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////
//
// Gett Mainloop Parameters
//
///////////////////////////////////////////////////////////
template<
class ElementAccumulator_,
class TensorA_, // (M, K, L)
class TensorB_ // (N, K, L)
, class TensorSfA_ = TensorA_,
class TensorSfB_ = TensorB_
>
struct GettMainloopParams {
using ElementAccumulator = ElementAccumulator_;
@ -79,23 +90,105 @@ struct GettMainloopParams {
ComplexTransform transform_A = ComplexTransform::kNone;
ComplexTransform transform_B = ComplexTransform::kNone;
using TensorSfA = TensorSfA_;
using TensorSfB = TensorSfB_;
using EngineSfA = typename TensorSfA::engine_type;
using LayoutSfA = typename TensorSfA::layout_type;
using EngineSfB = typename TensorSfB::engine_type;
using LayoutSfB = typename TensorSfB::layout_type;
TensorSfA_ SfA{};
TensorSfB_ SfB{};
GettMainloopParams() {}
GettMainloopParams(TensorA tensor_A, TensorB tensor_B)
: A(tensor_A), B(tensor_B) {}
GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
: A(tensor_A), SfA(tensor_SfA),
B(tensor_B), SfB(tensor_SfB) {}
};
////////////////////////////////////////////////////////////////////////
//
// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels
//
////////////////////////////////////////////////////////////////////////
/// Mainloop parameter pack for block-scaled GEMM reference kernels.
/// Thin convenience wrapper over GettMainloopParams that makes the scale-factor
/// tensors (SfA, SfB) mandatory template parameters and re-exports the base
/// class's type aliases.
template<
  class ElementAccumulator_,
  class TensorA_,   // (M, K, L)
  class TensorSfA_, // (M, K, L) scale factors for A
  class TensorB_,   // (N, K, L)
  class TensorSfB_  // (N, K, L) scale factors for B
>
struct GettBlockScalingMainloopParams : public GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_> {
  using Base = GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_>;
  // Re-exported base aliases so callers can name them via this specialization.
  using ElementAccumulator = typename Base::ElementAccumulator;
  using TensorA = typename Base::TensorA;
  using TensorB = typename Base::TensorB;
  using EngineA = typename Base::EngineA;
  using LayoutA = typename Base::LayoutA;
  using EngineB = typename Base::EngineB;
  using LayoutB = typename Base::LayoutB;
  // NOTE(review): these data members shadow the base-class members of the same
  // name and are initialized from the base subobject's values — confirm callers
  // only ever read/write one of the two copies.
  ComplexTransform transform_A = Base::transform_A;
  ComplexTransform transform_B = Base::transform_B;
  using TensorSfA = typename Base::TensorSfA;
  using TensorSfB = typename Base::TensorSfB;
  using EngineSfA = typename Base::EngineSfA;
  using LayoutSfA = typename Base::LayoutSfA;
  using EngineSfB = typename Base::EngineSfB;
  using LayoutSfB = typename Base::LayoutSfB;
  GettBlockScalingMainloopParams() {}
  // Forwards A/SfA/B/SfB to the base-class constructor.
  GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
    : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Selects whether the reference epilogue also generates scale factors for the
/// output tensor D (see compute_1d_scaling_factor_and_quantized_output).
enum class SfStrategy {
  None = 0,    // no scale-factor generation for D
  SfDGen = 1   // generate 1D (per-vector) scale factors for D
};
///////////////////////////////////////////////////////////
//
// Gett Epilogue Parameters
//
///////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementScalingFactor_,
class ElementAccumulator_,
class ElementCompute_,
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = TensorD_, // (M, 1)
class TensorAux_ = TensorD_, // (M, N, L)
class VectorAlpha_ = TensorD_, // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
class TensorAux_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, N, L)
class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
class TensorSFD_ = TensorD_,
class SFD_VectorSize_ = cute::Int<0>,
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
bool PerColumnBias_ = false
,
SfStrategy SfGenStrategy_ = SfStrategy::None
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
@ -108,6 +201,8 @@ struct GettEpilogueParams {
using VectorBias = VectorBias_;
using VectorAlpha = VectorAlpha_;
using VectorBeta = VectorBeta_;
using TensorSFD = TensorSFD_;
using SFD_VectorSize = SFD_VectorSize_;
using ActivationFunctor = ActivationFunctor_;
using BiasBinaryOp = BiasBinaryOp_;
@ -115,7 +210,11 @@ struct GettEpilogueParams {
using LayoutC = typename TensorC::layout_type;
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
using EngineSfD = typename TensorSFD::engine_type;
using LayoutSfD = typename TensorSFD::layout_type;
static constexpr bool PerColumnBias = PerColumnBias_;
static constexpr SfStrategy SfGenStrategy = SfGenStrategy_;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
@ -125,7 +224,8 @@ struct GettEpilogueParams {
TensorAux Aux{};
VectorAlpha Valpha{};
VectorBeta Vbeta{};
ElementCompute st = ElementCompute(1);
TensorSFD SfD{};
ElementCompute st = ElementCompute(1);
ElementAccumulator* abs_max_D = nullptr;
ElementAccumulator* abs_max_Aux = nullptr;
@ -137,8 +237,250 @@ struct GettEpilogueParams {
ElementScalingFactor scale_aux = ElementScalingFactor(1);
bool beta_per_channel_scaling = false;
GettEpilogueParams() {}
GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
: alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {}
GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
: alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {}
GettEpilogueParams(
ElementScalar alpha, ElementScalar beta,
TensorC tensor_C, TensorD tensor_D,
VectorBias bias, TensorAux tensor_aux,
VectorAlpha vector_alpha, VectorBeta vector_beta)
: alpha(alpha), beta(beta),
C(tensor_C), D(tensor_D),
Bias(bias), Aux(tensor_aux),
Valpha(vector_alpha), Vbeta(vector_beta) {}
};
////////////////////////////////////////////////////////////////////////
//
// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels
//
////////////////////////////////////////////////////////////////////////
/// Epilogue parameter pack for block-scaled GEMM reference kernels.
/// Convenience specialization of GettEpilogueParams that fixes the unused
/// bias/aux/alpha-vector/beta-vector tensors to ElementCompute-typed dummies,
/// exposes the D scale-factor tensor (SfD), and re-exports base aliases.
template<
  class ElementScalar_,
  class ElementAccumulator_,
  class ElementCompute_,
  class TensorC_,
  class TensorD_,
  class TensorSfD_ = TensorD_,            // scale factors for D
  class SFD_VectorSize_ = cute::Int<0>,   // vector length per scale factor
  SfStrategy SfGenStrategy_ = SfStrategy::None
>
struct GettBlockScalingEpilogueParams : public GettEpilogueParams<
    ElementScalar_, // ElementScalar
    ElementScalar_, // ElementScalingFactor
    ElementAccumulator_, // ElementAccumulator
    ElementCompute_, // ElementCompute
    TensorC_, // TensorC (M, N, L)
    TensorD_, // TensorD (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
    cutlass::epilogue::thread::Identity<ElementCompute_>, //
    TensorSfD_, // TensorSfD
    SFD_VectorSize_, // SFD_VectorSize
    cutlass::plus<ElementCompute_>, // class BiasBinaryOp_ =
    false, //PerColumnBias_
    SfGenStrategy_ // SfGenStrategy
  > {
  using Base = GettEpilogueParams<
    ElementScalar_, // ElementScalar
    ElementScalar_, // ElementScalingFactor
    ElementAccumulator_, // ElementAccumulator
    ElementCompute_, // ElementCompute
    TensorC_, // TensorC (M, N, L)
    TensorD_, // TensorD (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
    cutlass::epilogue::thread::Identity<ElementCompute_>, //
    TensorSfD_, // TensorSfD
    SFD_VectorSize_, // SFD_VectorSize
    cutlass::plus<ElementCompute_>, // BiasBinaryOp
    false, // PerColumnBias
    SfGenStrategy_ // SfGenStrategy
  >;
  // Re-exported base aliases so callers can name them via this specialization.
  using ElementScalar = typename Base::ElementScalar;
  using ElementScalingFactor = typename Base::ElementScalingFactor;
  using ElementAccumulator = typename Base::ElementAccumulator;
  using ElementCompute = typename Base::ElementCompute;
  using TensorC = typename Base::TensorC;
  using TensorD = typename Base::TensorD;
  using TensorAux = typename Base::TensorAux;
  using VectorBias = typename Base::VectorBias;
  using VectorAlpha = typename Base::VectorAlpha;
  using VectorBeta = typename Base::VectorBeta;
  using TensorSFD = typename Base::TensorSFD;
  using SFD_VectorSize = typename Base::SFD_VectorSize;
  using ActivationFunctor = typename Base::ActivationFunctor;
  using BiasBinaryOp = typename Base::BiasBinaryOp;
  using EngineC = typename Base::EngineC;
  using LayoutC = typename Base::LayoutC;
  using EngineD = typename Base::EngineD;
  using LayoutD = typename Base::LayoutD;
  using EngineSfD = typename Base::EngineSfD;
  using LayoutSfD = typename Base::LayoutSfD;
  static constexpr bool PerColumnBias = Base::PerColumnBias;
  static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy;
  GettBlockScalingEpilogueParams() {}
  // Plain GEMM epilogue: D = alpha * acc + beta * C, no SfD.
  GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
    : Base(alpha, beta, tensor_C, tensor_D) {}
  // NOTE(review): this overload forwards st = ElementCompute{0}, while the base
  // class defaults st to ElementCompute(1) — confirm a zero st is intended here.
  GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD)
    : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {}
  // Full form: explicit scale-factor tensor and st scalar.
  GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
    : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {}
};
///////////////////////////////////////////////////////////
//
// Generic Gett 3x Implementation
//
///////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Reference implementation of 1D (per-vector) scale-factor generation for the
/// output tensor D of a block-scaled GEMM.
///
/// For each contiguous vector of kVectorSize accumulator values (a column
/// segment when SfD is MN-major, a row segment otherwise), this routine:
///   1. finds the absolute maximum over the vector (NaN-propagating),
///   2. computes the quantized scale factor qpvscale = max * (st / ElementD_max)
///      and stores it into tensor_SfD,
///   3. rescales the accumulator values in place by st / qpvscale so a
///      subsequent narrow-precision store of D is consistent with SfD.
///
/// Parameters:
///   epilogue_params - supplies the global scale `st`; caller has already
///                     checked that SfD is non-null
///   tensor_D        - output tensor; used here only for its (M, N, L) extents
///   tensor_SfD      - destination tensor for the quantized scale factors
///   m, n, l         - coordinates of the top-left corner of this block
///   acc             - (kBlockM x kBlockN) accumulator block, modified in place
///
/// NOTE(review): the scale factor of a vector is written at the vector's first
/// D coordinate (sf_row / sf_col); assumes SfD's layout maps those coordinates
/// onto the compact scale-factor storage — confirm against the SfD layout used
/// by callers.
template <int kVectorSize, class EpilogueParams, class TensorD, class TensorSFD, class ElementCompute, int kBlockM, int kBlockN>
void compute_1d_scaling_factor_and_quantized_output(
    EpilogueParams const& epilogue_params,
    TensorD &tensor_D,
    TensorSFD &tensor_SfD,
    int64_t m,
    int64_t n,
    int64_t l,
    ElementCompute (&acc)[kBlockM][kBlockN])
{
  using ElementD = typename ElementTraits<typename EpilogueParams::EngineD::value_type>::type;
  using ElementSfD = typename ElementTraits<typename EpilogueParams::EngineSfD::value_type>::type;

  // Problem extents; out-of-range tail rows/cols of the block are skipped below.
  int const M = cute::size<0>(tensor_D.layout());
  int const N = cute::size<1>(tensor_D.layout());
  int const L = cute::size<2>(tensor_D.layout());

  auto mul = cutlass::multiplies<ElementCompute>{};
  auto div = divides<ElementCompute>{};

  // Get FP max
  ElementCompute fp_max = ElementCompute(std::numeric_limits<ElementD>::max());
  float scale_down_factor = div(1.0f, fp_max);

  // Get st' = st / FP max
  ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor);

  absolute_value_op<ElementCompute> abs_op;
  maximum_with_nan_propogation<ElementCompute> max_op;

  // A compile-time unit stride in mode (0,1) of SfD marks an MN-major layout:
  // scale-factor vectors then run down columns; otherwise they run along rows.
  if constexpr (cute::is_constant<1, decltype(cute::stride<0,1>(tensor_SfD))>::value) {
    // MN major output
    int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize);
    // Col major output
    for (int n_b = 0; n_b < kBlockN; ++n_b) {
      for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
        int64_t col = n + n_b;
        /// Step1: get max across a vector
        ElementCompute accum_max = ElementCompute(0);
        for (int v = 0; v < kVectorSize; v++) {
          int accum_row = v_b * kVectorSize + v;
          int64_t output_row = accum_row + m;
          // Only in-bounds elements contribute to the vector max.
          if (output_row < M && col < N) {
            accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b]));
          }
        }
        /// Step2: Compute Scale
        ElementCompute pvscale = mul(accum_max, st_scaled_down);
        ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
        // Store the Scaling Factors
        int64_t sf_row = m + kVectorSize * v_b;
        if (sf_row < M && col < N) {
          tensor_SfD(sf_row, col, l) = qpvscale;
        }
        /// Step3: Compute quantized output values
        ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
        // Get float reciprocal
        ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
        ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
        // Map INF to fp32::max
        acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
        // Store the intermediate_accum
        for (int v = 0; v < kVectorSize; v++) {
          int accum_row = v_b * kVectorSize + v;
          int64_t output_row = accum_row + m;
          if (output_row < M && col < N) {
            acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale);
          }
        }
      }
    }
  }
  else {
    int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize);
    // row major output
    for (int m_b = 0; m_b < kBlockM; ++m_b) {
      for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
        int64_t row = m + m_b;
        /// Step1: get max across a vector
        ElementCompute accum_max = ElementCompute(0);
        for (int v = 0; v < kVectorSize; v++) {
          int accum_col = v_b * kVectorSize + v;
          int64_t output_col = accum_col + n;
          // Only in-bounds elements contribute to the vector max.
          if (row < M && output_col < N) {
            accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col]));
          }
        }
        /// Step2: Compute Scale
        ElementCompute pvscale = mul(accum_max, st_scaled_down);
        ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
        // Store the Scaling Factors
        int64_t sf_col = n + kVectorSize * v_b;
        if (row < M && sf_col < N) {
          tensor_SfD(row, sf_col, l) = qpvscale;
        }
        /// Step3: Compute quantized output values
        ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
        // Get float reciprocal
        ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
        ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
        // Map INF to fp32::max
        acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
        // Store the intermediate_accum
        for (int v = 0; v < kVectorSize; v++) {
          int accum_col = v_b * kVectorSize + v;
          int64_t output_col = accum_col + n;
          if (row < M && output_col < N) {
            acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale);
          }
        }
      }
    }
  }
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - General Tensor-Tensor contraction reference kernel
@ -188,6 +530,11 @@ void gett_mainloop(
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
using ElementSFA = typename ElementTraits<typename MainloopParams::EngineSfA::value_type>::type;
using ElementSFB = typename ElementTraits<typename MainloopParams::EngineSfB::value_type>::type;
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
RingOp fma_op;
@ -207,6 +554,14 @@ void gett_mainloop(
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
if constexpr (not cute::is_same_v<ElementSFA, ElementA>){
// Load SFA
auto sfa = static_cast<ElementAccumulator>(mainloop_params.SfA(m + m_b, k, l));
a_frag[m_b] *= sfa;
}
if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
a_frag[m_b] = conj(a_frag[m_b]);
}
@ -222,6 +577,14 @@ void gett_mainloop(
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
if constexpr (not cute::is_same_v<ElementSFB, ElementB>){
// Load SFB
auto sfb = static_cast<ElementAccumulator>(mainloop_params.SfB(n + n_b, k, l));
b_frag[n_b] *= sfb;
}
if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
b_frag[n_b] = conj(b_frag[n_b]);
}
@ -259,6 +622,7 @@ void gett_epilogue(
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::TensorC::value_type;
using ElementD = typename EpilogueParams::TensorD::value_type;
using ElementSfD = typename EpilogueParams::TensorSFD::value_type;
using ElementAux = typename EpilogueParams::TensorAux::value_type;
using ElementBias = typename EpilogueParams::VectorBias::value_type;
using ElementScalar = typename EpilogueParams::ElementScalar;
@ -267,6 +631,8 @@ void gett_epilogue(
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy;
constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
@ -412,6 +778,17 @@ void gett_epilogue(
}
}
} // m_b
if constexpr (
SfGenStrategy == SfStrategy::SfDGen
) {
// 1d scale factor generation
constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{};
if (epilogue_params.SfD.data() != nullptr) {
compute_1d_scaling_factor_and_quantized_output<kVectorSize>(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum);
}
}
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {