CUTLASS 3.8 Release (#2059)
* CUTLASS 3.8 Release
* update
* Update README.md
* Revert "Update README.md"
This reverts commit b353e36fe8.
* update
* update
---------
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -221,6 +221,19 @@ cutlass_add_cutlass_library(
|
||||
|
||||
# files split for parallel compilation
|
||||
src/reference/gemm_int4.cu
|
||||
|
||||
src/reference/block_scaled_gemm_fp4a_vs16.cu
|
||||
src/reference/block_scaled_gemm_fp4a_vs32.cu
|
||||
src/reference/block_scaled_gemm_mixed8bitsa.cu
|
||||
src/reference/gemm_f4_f4_f32.cu
|
||||
src/reference/gemm_f4_f6_f32.cu
|
||||
src/reference/gemm_f4_f8_f32.cu
|
||||
src/reference/gemm_f6_f4_f32.cu
|
||||
src/reference/gemm_f6_f6_f32.cu
|
||||
src/reference/gemm_f6_f8_f32.cu
|
||||
src/reference/gemm_f8_f4_f32.cu
|
||||
src/reference/gemm_f8_f6_f32.cu
|
||||
|
||||
src/reference/gemm_s8_s8_s32.cu
|
||||
src/reference/gemm_u8_u8_s32.cu
|
||||
src/reference/gemm_int8_interleaved_32.cu
|
||||
|
||||
@ -119,6 +119,18 @@ template <> struct ArchMap<arch::Sm90, arch::OpClassSparseTensorOp> {
|
||||
static int const kMax = 90;
|
||||
};
|
||||
|
||||
|
||||
template <typename OperatorClass> struct ArchMap<arch::Sm100, OperatorClass> {
|
||||
static int const kMin = 100;
|
||||
static int const kMax = 1024;
|
||||
};
|
||||
|
||||
template <> struct ArchMap<arch::Sm100, arch::OpClassTensorOp> {
|
||||
static int const kMin = 100;
|
||||
static int const kMax = 100;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
|
||||
@ -300,6 +300,101 @@ struct GemmDescription : public OperationDescription {
|
||||
transform_B(transform_B) {}
|
||||
};
|
||||
|
||||
|
||||
/// Description of all GEMM computations
|
||||
struct BlockScaledGemmDescription : public OperationDescription {
|
||||
|
||||
/// Indicates the kind of GEMM performed
|
||||
GemmKind gemm_kind;
|
||||
|
||||
/// Describes the A operand
|
||||
TensorDescription A;
|
||||
|
||||
/// Describes the B operand
|
||||
TensorDescription B;
|
||||
|
||||
/// Describes the source matrix
|
||||
TensorDescription C;
|
||||
|
||||
/// Describes the destination matrix
|
||||
TensorDescription D;
|
||||
|
||||
/// Describes the SFA operand
|
||||
TensorDescription SFA;
|
||||
|
||||
/// Describes the SFB operand
|
||||
TensorDescription SFB;
|
||||
|
||||
/// Describes the SFD operand
|
||||
TensorDescription SFD;
|
||||
|
||||
/// Describes the data type of the scalars passed to the epilogue
|
||||
NumericTypeID element_epilogue;
|
||||
|
||||
/// Describes the structure of parallel reductions
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
/// Transformation on A operand
|
||||
ComplexTransform transform_A;
|
||||
|
||||
/// Transformation on B operand
|
||||
ComplexTransform transform_B;
|
||||
|
||||
/// Describes the input ScaleFactor VectorSize
|
||||
int SFVecSize;
|
||||
|
||||
/// Describes the Output ScaleFactor VectorSize
|
||||
int EpilogueSFVecSize;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
BlockScaledGemmDescription(
|
||||
GemmKind gemm_kind = GemmKind::kGemm,
|
||||
TensorDescription const& A = TensorDescription(),
|
||||
TensorDescription const& B = TensorDescription(),
|
||||
TensorDescription const& C = TensorDescription(),
|
||||
TensorDescription const& D = TensorDescription(),
|
||||
NumericTypeID element_epilogue = NumericTypeID::kInvalid,
|
||||
SplitKMode split_k_mode = SplitKMode::kNone,
|
||||
ComplexTransform transform_A = ComplexTransform::kNone,
|
||||
ComplexTransform transform_B = ComplexTransform::kNone
|
||||
):
|
||||
gemm_kind(gemm_kind),
|
||||
A(A),
|
||||
B(B),
|
||||
C(C),
|
||||
D(D),
|
||||
element_epilogue(element_epilogue),
|
||||
split_k_mode(split_k_mode),
|
||||
transform_A(transform_A),
|
||||
transform_B(transform_B) {}
|
||||
|
||||
BlockScaledGemmDescription(
|
||||
OperationDescription op_desc,
|
||||
GemmKind gemm_kind,
|
||||
TensorDescription const& A,
|
||||
TensorDescription const& B,
|
||||
TensorDescription const& C,
|
||||
TensorDescription const& D,
|
||||
NumericTypeID element_epilogue,
|
||||
SplitKMode split_k_mode,
|
||||
ComplexTransform transform_A,
|
||||
ComplexTransform transform_B
|
||||
):
|
||||
OperationDescription(op_desc),
|
||||
gemm_kind(gemm_kind),
|
||||
A(A),
|
||||
B(B),
|
||||
C(C),
|
||||
D(D),
|
||||
element_epilogue(element_epilogue),
|
||||
split_k_mode(split_k_mode),
|
||||
transform_A(transform_A),
|
||||
transform_B(transform_B) {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Description for structured sparse GEMMs.
|
||||
|
||||
@ -178,6 +178,15 @@ public:
|
||||
int M, /// GEMM M dimension
|
||||
int N, /// GEMM N dimension
|
||||
int K, /// GEMM K dimension
|
||||
|
||||
int cluster_m, /// cluster shape M dimension
|
||||
int cluster_n, /// cluster shape N dimension
|
||||
int cluster_k, /// cluster shape K dimension
|
||||
int cluster_m_fallback, /// Fallback cluster shape M dimension
|
||||
int cluster_n_fallback, /// Fallback cluster shape N dimension
|
||||
int cluster_k_fallback, /// Fallback cluster shape K dimension
|
||||
|
||||
|
||||
NumericTypeID element_compute, /// Data type of internal accumulation
|
||||
|
||||
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
|
||||
|
||||
@ -103,6 +103,7 @@ public:
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const = 0;
|
||||
|
||||
// Originally designed for metadata, but should be useful for FP8/6/4 too.
|
||||
virtual Status initialize_with_profiler_workspace(
|
||||
void const *configuration,
|
||||
void *host_workspace,
|
||||
@ -269,6 +270,8 @@ struct GemmUniversalConfiguration {
|
||||
|
||||
GemmUniversalMode mode{GemmUniversalMode::kGemm};
|
||||
gemm::GemmCoord problem_size{};
|
||||
gemm::GemmCoord cluster_shape{};
|
||||
gemm::GemmCoord cluster_shape_fallback{};
|
||||
int batch_count{1};
|
||||
|
||||
int64_t lda{0};
|
||||
@ -282,6 +285,8 @@ struct GemmUniversalConfiguration {
|
||||
struct GemmUniversalArguments {
|
||||
// NOTE: these are replicated for 3.0 interfaces
|
||||
gemm::GemmCoord problem_size{};
|
||||
gemm::GemmCoord cluster_shape{};
|
||||
gemm::GemmCoord cluster_shape_fallback{};
|
||||
int batch_count{1};
|
||||
|
||||
void const *A{nullptr};
|
||||
@ -307,13 +312,68 @@ struct GemmUniversalArguments {
|
||||
// Needed for some 3.x kernels
|
||||
int sm_count{0};
|
||||
library::RasterOrder raster_order{};
|
||||
library::RuntimeDatatype runtime_input_datatype_a{};
|
||||
library::RuntimeDatatype runtime_input_datatype_b{};
|
||||
int swizzle_size{1};
|
||||
int split_k_slices{1};
|
||||
|
||||
int device_index{0};
|
||||
|
||||
bool use_pdl{false};
|
||||
};
|
||||
|
||||
|
||||
/// Block Scaled GEMM
|
||||
//
|
||||
// OperationKind: kBlockScaledGemm
|
||||
// GemmKind: Universal
|
||||
|
||||
struct BlockScaledGemmArguments {
|
||||
// NOTE: these are replicated for 3.0 interfaces
|
||||
gemm::GemmCoord problem_size{};
|
||||
gemm::GemmCoord cluster_shape{};
|
||||
gemm::GemmCoord cluster_shape_fallback{};
|
||||
int batch_count{1};
|
||||
|
||||
void const *A{nullptr};
|
||||
void const *B{nullptr};
|
||||
void const *SFA{nullptr};
|
||||
void const *SFB{nullptr};
|
||||
void const *C{nullptr};
|
||||
void *D{nullptr};
|
||||
void *SFD{nullptr};
|
||||
|
||||
void const *alpha{nullptr};
|
||||
void const *beta{nullptr};
|
||||
ScalarPointerMode pointer_mode{};
|
||||
|
||||
// NOTE: these are replicated for 3.0 interfaces
|
||||
int64_t lda{0};
|
||||
int64_t ldb{0};
|
||||
int64_t ldc{0};
|
||||
int64_t ldd{0};
|
||||
|
||||
int64_t batch_stride_A{0};
|
||||
int64_t batch_stride_B{0};
|
||||
int64_t batch_stride_C{0};
|
||||
int64_t batch_stride_D{0};
|
||||
|
||||
// Needed for ScaleFactor Generation
|
||||
void const *norm_constant{nullptr};
|
||||
|
||||
// Needed for some 3.x kernels
|
||||
int sm_count{0};
|
||||
library::RasterOrder raster_order{};
|
||||
int swizzle_size{1};
|
||||
int split_k_slices{1};
|
||||
|
||||
library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic};
|
||||
library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic};
|
||||
|
||||
bool use_pdl{false};
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Complex valued GEMM in which real and imaginary parts are separated by a stride
|
||||
|
||||
@ -243,6 +243,191 @@ using GemmOperationFunctionalMap = std::unordered_map<
|
||||
GemmFunctionalKeyHasher
|
||||
>;
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Data Structures for BlockScaled Gemm Functional Maps
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tuple uniquely identifying Gemm functional behavior
|
||||
struct BlockScaledGemmFunctionalKey {
|
||||
|
||||
Provider provider;
|
||||
GemmKind gemm_kind;
|
||||
OperationKind kind;
|
||||
NumericTypeID element_compute;
|
||||
NumericTypeID element_scalar;
|
||||
NumericTypeID element_A;
|
||||
LayoutTypeID layout_A;
|
||||
NumericTypeID element_SFA;
|
||||
NumericTypeID element_B;
|
||||
LayoutTypeID layout_B;
|
||||
NumericTypeID element_SFB;
|
||||
NumericTypeID element_C;
|
||||
LayoutTypeID layout_C;
|
||||
NumericTypeID element_D;
|
||||
LayoutTypeID layout_D;
|
||||
NumericTypeID element_SFD;
|
||||
LayoutTypeID layout_SFD;
|
||||
int SFVecSize;
|
||||
int EpilogueSFVecSize;
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
inline
|
||||
BlockScaledGemmFunctionalKey(
|
||||
Provider provider,
|
||||
GemmKind gemm_kind = GemmKind::kGemm,
|
||||
OperationKind kind = OperationKind::kBlockScaledGemm,
|
||||
NumericTypeID element_compute = NumericTypeID::kF32,
|
||||
NumericTypeID element_scalar = NumericTypeID::kF32,
|
||||
NumericTypeID element_A = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_A = LayoutTypeID::kColumnMajor,
|
||||
NumericTypeID element_SFA = NumericTypeID::kF16,
|
||||
NumericTypeID element_B = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_B = LayoutTypeID::kColumnMajor,
|
||||
NumericTypeID element_SFB = NumericTypeID::kF16,
|
||||
NumericTypeID element_C = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_C = LayoutTypeID::kColumnMajor,
|
||||
NumericTypeID element_D = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_D = LayoutTypeID::kColumnMajor,
|
||||
NumericTypeID element_SFD = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor,
|
||||
int sf_vec_size = 32
|
||||
, int epilogue_sf_vec_size = 32
|
||||
):
|
||||
provider(provider),
|
||||
gemm_kind(gemm_kind),
|
||||
kind(kind),
|
||||
element_compute(element_compute),
|
||||
element_scalar(element_scalar),
|
||||
element_A(element_A),
|
||||
layout_A(layout_A),
|
||||
element_SFA(element_SFA),
|
||||
element_B(element_B),
|
||||
layout_B(layout_B),
|
||||
element_SFB(element_SFB),
|
||||
element_C(element_C),
|
||||
layout_C(layout_C),
|
||||
element_D(element_D),
|
||||
layout_D(layout_D),
|
||||
element_SFD(element_SFD),
|
||||
layout_SFD(layout_SFD),
|
||||
SFVecSize(sf_vec_size)
|
||||
, EpilogueSFVecSize(epilogue_sf_vec_size)
|
||||
{ }
|
||||
|
||||
inline
|
||||
bool operator==(BlockScaledGemmFunctionalKey const &rhs) const {
|
||||
return
|
||||
(provider == rhs.provider) &&
|
||||
(gemm_kind == rhs.gemm_kind) &&
|
||||
(kind == rhs.kind) &&
|
||||
(element_compute == rhs.element_compute) &&
|
||||
(element_scalar == rhs.element_scalar) &&
|
||||
(element_A == rhs.element_A) &&
|
||||
(layout_A == rhs.layout_A) &&
|
||||
(element_SFA == rhs.element_SFA) &&
|
||||
(element_B == rhs.element_B) &&
|
||||
(layout_B == rhs.layout_B) &&
|
||||
(element_SFB == rhs.element_SFB) &&
|
||||
(element_C == rhs.element_C) &&
|
||||
(layout_C == rhs.layout_C) &&
|
||||
(element_D == rhs.element_D) &&
|
||||
(layout_D == rhs.layout_D) &&
|
||||
(element_SFD == rhs.element_SFD) &&
|
||||
(layout_SFD == rhs.layout_SFD) &&
|
||||
(SFVecSize == rhs.SFVecSize)
|
||||
&& (EpilogueSFVecSize == rhs.EpilogueSFVecSize)
|
||||
;
|
||||
}
|
||||
|
||||
inline
|
||||
bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::library::BlockScaledGemmFunctionalKey const &k) {
|
||||
|
||||
out << "{\n"
|
||||
<< " provider: " << to_string(k.provider) << "\n"
|
||||
<< " gemm_kind: " << to_string(k.gemm_kind) << "\n"
|
||||
<< " kind: " << to_string(k.kind) << "\n"
|
||||
<< " element_compute: " << to_string(k.element_compute) << "\n"
|
||||
<< " element_scalar: " << to_string(k.element_scalar) << "\n"
|
||||
<< " element_A: " << to_string(k.element_A) << "\n"
|
||||
<< " layout_A: " << to_string(k.layout_A) << "\n"
|
||||
<< " element_SFA: " << to_string(k.element_SFA) << "\n"
|
||||
<< " element_B: " << to_string(k.element_B) << "\n"
|
||||
<< " layout_B: " << to_string(k.layout_B) << "\n"
|
||||
<< " element_SFB: " << to_string(k.element_SFB) << "\n"
|
||||
<< " element_C: " << to_string(k.element_C) << "\n"
|
||||
<< " layout_C: " << to_string(k.layout_C) << "\n"
|
||||
<< " element_D: " << to_string(k.element_D) << "\n"
|
||||
<< " layout_D: " << to_string(k.layout_D) << "\n"
|
||||
<< " element_SFD: " << to_string(k.element_SFD) << "\n"
|
||||
<< " layout_SFD: " << to_string(k.layout_SFD) << "\n"
|
||||
<< " SFVecSize: " << k.SFVecSize << "\n"
|
||||
<< "EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n"
|
||||
<< "}";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Hash function for BlockScaledGemmFunctionalKeyHasher
|
||||
struct BlockScaledGemmFunctionalKeyHasher {
|
||||
using IntHash = std::hash<int>;
|
||||
|
||||
inline
|
||||
static size_t rotl(size_t key, int shl) {
|
||||
return (key << shl) | (key >> (sizeof(key)*8u - static_cast<size_t>(shl)));
|
||||
}
|
||||
|
||||
inline
|
||||
size_t operator()(BlockScaledGemmFunctionalKey const &key) const {
|
||||
IntHash hash;
|
||||
|
||||
return
|
||||
rotl(hash(int(key.provider)), 1) ^
|
||||
rotl(hash(int(key.gemm_kind)), 2) ^
|
||||
rotl(hash(int(key.kind)), 3) ^
|
||||
rotl(hash(int(key.element_compute)), 4) ^
|
||||
rotl(hash(int(key.element_scalar)), 5) ^
|
||||
rotl(hash(int(key.element_A)), 6) ^
|
||||
rotl(hash(int(key.layout_A)), 7) ^
|
||||
rotl(hash(int(key.element_SFA)), 8) ^
|
||||
rotl(hash(int(key.element_B)), 9) ^
|
||||
rotl(hash(int(key.layout_B)), 10) ^
|
||||
rotl(hash(int(key.element_SFB)), 11) ^
|
||||
rotl(hash(int(key.element_C)), 12) ^
|
||||
rotl(hash(int(key.layout_C)), 13) ^
|
||||
rotl(hash(int(key.element_D)), 14) ^
|
||||
rotl(hash(int(key.layout_D)), 15) ^
|
||||
rotl(hash(int(key.element_SFD)), 16) ^
|
||||
rotl(hash(int(key.layout_SFD)), 17) ^
|
||||
rotl(hash(int(key.SFVecSize)), 18) ^
|
||||
rotl(hash(int(key.EpilogueSFVecSize)), 19)
|
||||
;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm
|
||||
using BlockScaledGemmOperationFunctionalMap = std::unordered_map<
|
||||
BlockScaledGemmFunctionalKey,
|
||||
GemmOperationVectorMap,
|
||||
BlockScaledGemmFunctionalKeyHasher
|
||||
>;
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Data Structures for Conv Functional Maps
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -509,6 +694,9 @@ public:
|
||||
// provider (kCUTLASS)
|
||||
GemmOperationFunctionalMap gemm_operations;
|
||||
|
||||
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
|
||||
BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations;
|
||||
|
||||
/// Map of all operations of type kConv2d
|
||||
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
|
||||
ConvOperationFunctionalMap conv2d_operations;
|
||||
|
||||
@ -43,6 +43,7 @@ enum class LayoutTypeID {
|
||||
kUnknown,
|
||||
kColumnMajor,
|
||||
kRowMajor,
|
||||
kBlockScalingTensor,
|
||||
kColumnMajorInterleavedK2,
|
||||
kRowMajorInterleavedK2,
|
||||
kColumnMajorInterleavedK4,
|
||||
@ -83,6 +84,16 @@ enum class NumericTypeID {
|
||||
kS64,
|
||||
kFE4M3,
|
||||
kFE5M2,
|
||||
|
||||
kFE2M3,
|
||||
kFE3M2,
|
||||
kFE2M1,
|
||||
kFUE8M0,
|
||||
kFUE4M3,
|
||||
kF8,
|
||||
kF6,
|
||||
kF4,
|
||||
|
||||
kF16,
|
||||
kBF16,
|
||||
kTF32,
|
||||
@ -131,6 +142,7 @@ enum class Provider {
|
||||
/// Enumeration indicating the kind of operation
|
||||
enum class OperationKind {
|
||||
kGemm,
|
||||
kBlockScaledGemm,
|
||||
kRankK,
|
||||
kRank2K,
|
||||
kTrmm,
|
||||
@ -165,6 +177,7 @@ enum class OpcodeClassID {
|
||||
kTensorOp,
|
||||
kWmmaTensorOp,
|
||||
kSparseTensorOp,
|
||||
kBlockScaledOp,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
@ -188,6 +201,7 @@ enum class MathOperationID {
|
||||
/// Enumeration indicating what kind of GEMM operation to perform
|
||||
enum class GemmKind {
|
||||
kGemm,
|
||||
kBlockScaledGemm,
|
||||
kSparse,
|
||||
kUniversal,
|
||||
kPlanarComplex,
|
||||
@ -251,6 +265,20 @@ enum class EpilogueKind {
|
||||
kInvalid
|
||||
};
|
||||
|
||||
|
||||
enum class RuntimeDatatype {
|
||||
kStatic,
|
||||
kE4M3,
|
||||
kE5M2,
|
||||
|
||||
kE3M2,
|
||||
kE2M3,
|
||||
kE2M1,
|
||||
|
||||
kInvalid
|
||||
};
|
||||
|
||||
|
||||
enum class RasterOrder {
|
||||
kAlongN,
|
||||
kAlongM,
|
||||
|
||||
@ -170,6 +170,15 @@ char const *to_string(ConvKind type, bool pretty = false);
|
||||
template <>
|
||||
ConvKind from_string<ConvKind>(std::string const &str);
|
||||
|
||||
|
||||
/// Converts a RuntimeDatatype enumerant to a string
|
||||
char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false);
|
||||
|
||||
/// Convers a RuntimeDatatype enumerant from a string
|
||||
template<>
|
||||
cutlass::library::RuntimeDatatype from_string<cutlass::library::RuntimeDatatype>(std::string const &str);
|
||||
|
||||
|
||||
/// Converts a RasterOrder enumerant to a string
|
||||
char const *to_string(RasterOrder type, bool pretty = false);
|
||||
|
||||
@ -202,6 +211,8 @@ bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t
|
||||
/// Casts from a real value represented as a double to the destination type. Returns true if successful.
|
||||
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
|
||||
|
||||
NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
|
||||
450
tools/library/src/block_scaled_gemm_operation_3x.hpp
Normal file
450
tools/library/src/block_scaled_gemm_operation_3x.hpp
Normal file
@ -0,0 +1,450 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Defines operations for all GEMM operation kinds in CUTLASS Library.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "library_internal.h"
|
||||
#include "gemm_operation_3x.hpp"
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Operator_>
|
||||
class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
|
||||
public:
|
||||
using Operator = Operator_;
|
||||
using OperatorArguments = typename Operator::Arguments;
|
||||
using ElementA = typename Operator::CollectiveMainloop::ElementA;
|
||||
using ElementSFA = typename Operator::CollectiveMainloop::ElementSF;
|
||||
using LayoutA = typename Operator::LayoutA;
|
||||
using ElementB = typename Operator::CollectiveMainloop::ElementB;
|
||||
using ElementSFB = typename Operator::CollectiveMainloop::ElementSF;
|
||||
using LayoutB = typename Operator::LayoutB;
|
||||
using ElementC = typename Operator::ElementC;
|
||||
using LayoutC = typename Operator::LayoutC;
|
||||
using ElementD = typename Operator::ElementD;
|
||||
using LayoutD = typename Operator::LayoutD;
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using TiledMma = typename Operator::CollectiveMainloop::TiledMma;
|
||||
constexpr static int SFVecSize = TiledMma::SFVecSize;
|
||||
|
||||
using CollectiveMainloop = typename Operator::CollectiveMainloop;
|
||||
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
|
||||
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
using Sm100BlkScaledConfig = typename CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v<typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;
|
||||
static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize;
|
||||
using ElementSFD = cute::conditional_t<epilogue_scalefactor_generation, typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;
|
||||
using LayoutSFD = cute::conditional_t<epilogue_scalefactor_generation, typename ThreadEpilogueOp::GmemLayoutTagScalefactor, LayoutD>;
|
||||
|
||||
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
|
||||
|
||||
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
|
||||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
|
||||
"ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
|
||||
|
||||
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
|
||||
using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA;
|
||||
using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB;
|
||||
|
||||
|
||||
private:
|
||||
BlockScaledGemmDescription description_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"):
|
||||
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
|
||||
description_.kind = OperationKind::kBlockScaledGemm;
|
||||
description_.SFA.element = NumericTypeMap<ElementSFA>::kId;
|
||||
description_.SFA.layout = LayoutTypeID::kRowMajor;
|
||||
description_.SFA.alignment = 128;
|
||||
description_.SFA.log_extent_range = 32;
|
||||
description_.SFA.log_stride_range = 32;
|
||||
|
||||
description_.SFB.element = NumericTypeMap<ElementSFB>::kId;
|
||||
description_.SFB.layout = LayoutTypeID::kRowMajor;
|
||||
description_.SFB.alignment = 128;
|
||||
description_.SFB.log_extent_range = 32;
|
||||
description_.SFB.log_stride_range = 32;
|
||||
|
||||
description_.SFVecSize = SFVecSize;
|
||||
|
||||
description_.SFD = make_TensorDescription<ElementSFD, LayoutSFD>(128);
|
||||
description_.EpilogueSFVecSize = SFD_VectorSize;
|
||||
|
||||
|
||||
description_.name = name;
|
||||
description_.provider = Provider::kCUTLASS;
|
||||
description_.gemm_kind = GemmKind::kUniversal;
|
||||
|
||||
description_.tile_description.threadblock_shape = make_Coord(
|
||||
Operator::ThreadblockShape::kM,
|
||||
Operator::ThreadblockShape::kN,
|
||||
Operator::ThreadblockShape::kK);
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
|
||||
description_.tile_description.cluster_shape = make_Coord(
|
||||
Operator::ClusterShape::kM,
|
||||
Operator::ClusterShape::kN,
|
||||
Operator::ClusterShape::kK);
|
||||
}
|
||||
|
||||
description_.tile_description.threadblock_stages = Operator::kStages;
|
||||
|
||||
description_.tile_description.warp_count = make_Coord(
|
||||
Operator::WarpCount::kM,
|
||||
Operator::WarpCount::kN,
|
||||
Operator::WarpCount::kK);
|
||||
|
||||
description_.tile_description.math_instruction.instruction_shape = make_Coord(
|
||||
Operator::InstructionShape::kM,
|
||||
Operator::InstructionShape::kN,
|
||||
Operator::InstructionShape::kK);
|
||||
|
||||
description_.tile_description.math_instruction.element_accumulator =
|
||||
NumericTypeMap<ElementAccumulator>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.opcode_class =
|
||||
OpcodeClassMap<typename Operator::OperatorClass>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.math_operation =
|
||||
MathOperationMap<typename Operator::MathOperator>::kId;
|
||||
|
||||
description_.tile_description.minimum_compute_capability =
|
||||
ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMin;
|
||||
|
||||
description_.tile_description.maximum_compute_capability =
|
||||
ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMax;
|
||||
|
||||
description_.A = make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
|
||||
description_.B = make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
|
||||
description_.C = make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
|
||||
description_.D = make_TensorDescription<ElementD, LayoutD>(Operator::kAlignmentD);
|
||||
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
||||
|
||||
description_.split_k_mode = SplitKMode::kNone;
|
||||
}
|
||||
|
||||
/// Returns the description of the GEMM operation
|
||||
virtual OperationDescription const & description() const {
|
||||
return description_;
|
||||
}
|
||||
|
||||
/// Returns the description of the GEMM operation
|
||||
BlockScaledGemmDescription const& get_gemm_description() const {
|
||||
return description_;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status construct_arguments_(
|
||||
OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
|
||||
// NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
|
||||
// Do nothing here and construct kernel arguments in update_arguments_ instead
|
||||
// We also cannot construct TMA descriptors without all the arguments available
|
||||
|
||||
operator_args.mode = configuration->mode;
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
template<class FusionArgs, class = void>
|
||||
struct UpdateFusionArgs {
|
||||
static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) {
|
||||
// If a custom EVT is instantiated then it is the users's responsibility
|
||||
// to ensure alpha and beta are updated appropriately
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
template<class FusionArgs>
|
||||
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
|
||||
static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) {
|
||||
|
||||
if constexpr (epilogue_scalefactor_generation) {
|
||||
fusion_args.block_scale_factor_ptr = static_cast<ElementSFD*>(arguments.SFD);
|
||||
fusion_args.norm_constant_ptr = static_cast<ElementCompute const *>(arguments.norm_constant);
|
||||
}
|
||||
|
||||
|
||||
if (arguments.pointer_mode == ScalarPointerMode::kHost) {
|
||||
fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
|
||||
fusion_args.beta = *static_cast<ElementCompute const *>(arguments.beta);
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else if (arguments.pointer_mode == ScalarPointerMode::kDevice) {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = static_cast<ElementCompute const *>(arguments.alpha);
|
||||
fusion_args.beta_ptr = static_cast<ElementCompute const *>(arguments.beta);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status update_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
BlockScaledGemmArguments const *arguments) {
|
||||
Status status = Status::kSuccess;
|
||||
|
||||
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
|
||||
operator_args.epilogue.thread, *arguments);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
operator_args.problem_shape = cute::make_shape(
|
||||
arguments->problem_size.m(),
|
||||
arguments->problem_size.n(),
|
||||
arguments->problem_size.k(),
|
||||
arguments->batch_count);
|
||||
|
||||
// update arguments
|
||||
|
||||
if constexpr (IsRuntimeDataType) {
|
||||
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);
|
||||
|
||||
using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA;
|
||||
using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB;
|
||||
|
||||
static_assert(cute::is_same_v<RuntimeDataTypeA, RuntimeDataTypeB>,
|
||||
"RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format");
|
||||
using RuntimeDatatypeArg = RuntimeDataTypeA;
|
||||
|
||||
auto mapping = [](RuntimeDatatype type) {
|
||||
if constexpr (cute::is_same_v<RuntimeDatatypeArg, cute::UMMA::MXF8F6F4Format>) {
|
||||
if (type == RuntimeDatatype::kE3M2) {
|
||||
return cute::UMMA::MXF8F6F4Format::E3M2;
|
||||
} else if (type == RuntimeDatatype::kE2M3) {
|
||||
return cute::UMMA::MXF8F6F4Format::E2M3;
|
||||
} else if (type == RuntimeDatatype::kE2M1) {
|
||||
return cute::UMMA::MXF8F6F4Format::E2M1;
|
||||
} else {
|
||||
assert("Invalid input datatype.");
|
||||
}
|
||||
}
|
||||
else if constexpr (cute::is_same_v<RuntimeDatatypeArg, cute::UMMA::MXF4Format>) {
|
||||
if (type == RuntimeDatatype::kE2M1) {
|
||||
return cute::UMMA::MXF4Format::E2M1;
|
||||
} else {
|
||||
assert("Invalid input datatype.");
|
||||
}
|
||||
}
|
||||
// BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
};
|
||||
|
||||
operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a);
|
||||
operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b);
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
|
||||
}
|
||||
operator_args.mainloop.ptr_SFA = static_cast<ElementSFA const *>(arguments->SFA);
|
||||
operator_args.mainloop.ptr_SFB = static_cast<ElementSFB const *>(arguments->SFB);
|
||||
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
|
||||
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
|
||||
|
||||
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
|
||||
arguments->lda, arguments->batch_stride_A);
|
||||
operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
|
||||
arguments->ldb, arguments->batch_stride_B);
|
||||
operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
|
||||
arguments->ldc, arguments->batch_stride_C);
|
||||
operator_args.epilogue.dD = operator_args.epilogue.dC;
|
||||
|
||||
operator_args.mainloop.layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape);
|
||||
operator_args.mainloop.layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape);
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
operator_args.hw_info.sm_count = arguments->sm_count;
|
||||
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
|
||||
operator_args.scheduler.max_swizzle_size = arguments->swizzle_size;
|
||||
}
|
||||
|
||||
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.raster_order)>) {
|
||||
using Enum_t = decltype(operator_args.scheduler.raster_order);
|
||||
switch (arguments->raster_order) {
|
||||
case RasterOrder::kAlongN:
|
||||
operator_args.scheduler.raster_order = Enum_t::AlongN;
|
||||
break;
|
||||
case RasterOrder::kAlongM:
|
||||
operator_args.scheduler.raster_order = Enum_t::AlongM;
|
||||
break;
|
||||
default:
|
||||
operator_args.scheduler.raster_order = Enum_t::Heuristic;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
|
||||
operator_args.scheduler.splits = arguments->split_k_slices;
|
||||
}
|
||||
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
|
||||
operator_args.hw_info.cluster_shape = dim3(
|
||||
arguments->cluster_shape.m(),
|
||||
arguments->cluster_shape.n(),
|
||||
arguments->cluster_shape.k());
|
||||
operator_args.hw_info.cluster_shape_fallback = dim3(
|
||||
arguments->cluster_shape_fallback.m(),
|
||||
arguments->cluster_shape_fallback.n(),
|
||||
arguments->cluster_shape_fallback.k());
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Returns success if the operation can proceed
|
||||
Status can_implement(
|
||||
void const *configuration_ptr, void const *arguments_ptr) const override {
|
||||
|
||||
GemmUniversalConfiguration const *configuration =
|
||||
static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
|
||||
BlockScaledGemmArguments const *arguments =
|
||||
static_cast<BlockScaledGemmArguments const *>(arguments_ptr);
|
||||
|
||||
OperatorArguments args;
|
||||
auto status = update_arguments_(args, arguments);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// can_implement rules may need access to problem shape
|
||||
args.problem_shape = cute::make_shape(
|
||||
configuration->problem_size.m(),
|
||||
configuration->problem_size.n(),
|
||||
configuration->problem_size.k(),
|
||||
configuration->batch_count);
|
||||
|
||||
return Operator::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the host-side workspace
|
||||
uint64_t get_host_workspace_size(void const *configuration) const override {
|
||||
return sizeof(Operator);
|
||||
}
|
||||
|
||||
/// Gets the device-side workspace
|
||||
uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr,void const *arguments_ptr) const override {
|
||||
|
||||
OperatorArguments args;
|
||||
auto status = update_arguments_(
|
||||
args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr));
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t size = Operator::get_workspace_size(args);
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Initializes the workspace
|
||||
Status initialize(
|
||||
void const *configuration_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace,
|
||||
cudaStream_t stream = nullptr) const override {
|
||||
Operator *op = new (host_workspace) Operator;
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
Status initialize_with_profiler_workspace(
|
||||
void const *configuration,
|
||||
void *host_workspace,
|
||||
void *device_workspace,
|
||||
uint8_t **profiler_workspaces,
|
||||
int problem_count_from_profiler,
|
||||
cudaStream_t stream = nullptr) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel
|
||||
Status run(
|
||||
void const *arguments_ptr,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const override {
|
||||
|
||||
OperatorArguments args;
|
||||
Status status = update_arguments_(args, static_cast<BlockScaledGemmArguments const *>(arguments_ptr));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator *op = static_cast<Operator *>(host_workspace);
|
||||
// We need to call initialize() since we have to rebuild TMA desc for every new set of args
|
||||
status = op->run(args, device_workspace, stream, nullptr, static_cast<BlockScaledGemmArguments const *>(arguments_ptr)->use_pdl);
|
||||
return status;
|
||||
}
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::library
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -162,6 +162,18 @@ public:
|
||||
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
|
||||
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
|
||||
|
||||
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
|
||||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
|
||||
"ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
|
||||
|
||||
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
@ -235,8 +247,42 @@ protected:
|
||||
arguments->batch_count);
|
||||
|
||||
// update arguments
|
||||
|
||||
|
||||
if constexpr (IsRuntimeDataType) {
|
||||
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);
|
||||
|
||||
std::unordered_map<RuntimeDatatype, cute::UMMA::MXF8F6F4Format> mapping = {
|
||||
{RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
|
||||
{RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
|
||||
{RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
|
||||
{RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
|
||||
};
|
||||
|
||||
auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
|
||||
auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);
|
||||
|
||||
if (iter_runtime_a != mapping.end()) {
|
||||
operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
|
||||
} else {
|
||||
assert("invalid runtime argument for datatype A!");
|
||||
}
|
||||
|
||||
if (iter_runtime_b != mapping.end()) {
|
||||
operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
|
||||
} else {
|
||||
assert("invalid runtime argument for datatype B!");
|
||||
}
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
|
||||
}
|
||||
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
|
||||
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
|
||||
|
||||
@ -277,6 +323,22 @@ protected:
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
|
||||
operator_args.scheduler.splits = arguments->split_k_slices;
|
||||
}
|
||||
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
|
||||
operator_args.hw_info.cluster_shape = dim3(
|
||||
arguments->cluster_shape.m(),
|
||||
arguments->cluster_shape.n(),
|
||||
arguments->cluster_shape.k());
|
||||
operator_args.hw_info.cluster_shape_fallback = dim3(
|
||||
arguments->cluster_shape_fallback.m(),
|
||||
arguments->cluster_shape_fallback.n(),
|
||||
arguments->cluster_shape_fallback.k());
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
|
||||
@ -510,6 +510,15 @@ Status Handle::gemm_universal(
|
||||
int M, /// GEMM M dimension
|
||||
int N, /// GEMM N dimension
|
||||
int K, /// GEMM K dimension
|
||||
|
||||
int cluster_m, /// cluster shape M dimension
|
||||
int cluster_n, /// cluster shape N dimension
|
||||
int cluster_k, /// cluster shape K dimension
|
||||
int cluster_m_fallback, /// Fallback cluster shape M dimension
|
||||
int cluster_n_fallback, /// Fallback cluster shape N dimension
|
||||
int cluster_k_fallback, /// Fallback cluster shape K dimension
|
||||
|
||||
|
||||
NumericTypeID element_compute, /// Data type of internal accumulation
|
||||
|
||||
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
|
||||
@ -629,6 +638,8 @@ Status Handle::gemm_universal(
|
||||
GemmUniversalConfiguration configuration{
|
||||
mode,
|
||||
{M, N, K},
|
||||
{cluster_m, cluster_n, cluster_k},
|
||||
{cluster_m_fallback, cluster_n_fallback, cluster_k_fallback},
|
||||
batch_count,
|
||||
lda,
|
||||
ldb,
|
||||
@ -647,6 +658,8 @@ Status Handle::gemm_universal(
|
||||
|
||||
GemmUniversalArguments arguments{
|
||||
{M, N, K},
|
||||
{cluster_m, cluster_n, cluster_k},
|
||||
{cluster_m_fallback, cluster_n_fallback, cluster_k_fallback},
|
||||
batch_count,
|
||||
ptr_A,
|
||||
ptr_B,
|
||||
|
||||
@ -116,6 +116,27 @@ template <> struct NumericTypeMap<cutlass::float_e5m2_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kFE5M2;
|
||||
};
|
||||
|
||||
|
||||
template <> struct NumericTypeMap<cutlass::float_e2m3_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kFE2M3;
|
||||
};
|
||||
|
||||
template <> struct NumericTypeMap<cutlass::float_e3m2_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kFE3M2;
|
||||
};
|
||||
|
||||
template <> struct NumericTypeMap<cutlass::float_e2m1_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kFE2M1;
|
||||
};
|
||||
template <> struct NumericTypeMap<cutlass::float_ue8m0_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kFUE8M0;
|
||||
};
|
||||
|
||||
template <> struct NumericTypeMap<cutlass::float_ue4m3_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kFUE4M3;
|
||||
};
|
||||
|
||||
|
||||
template <> struct NumericTypeMap<uint16_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kU16;
|
||||
};
|
||||
@ -161,6 +182,21 @@ template <> struct NumericTypeMap<cutlass::tfloat32_t> {
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float8_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kF8;
|
||||
};
|
||||
|
||||
template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float6_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kF6;
|
||||
};
|
||||
|
||||
template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float4_t> {
|
||||
static NumericTypeID const kId = NumericTypeID::kF4;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T> struct MathOperationMap {
|
||||
@ -300,6 +336,12 @@ template <> struct OpcodeClassMap<arch::OpClassSparseTensorOp> {
|
||||
static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp;
|
||||
};
|
||||
|
||||
|
||||
template <> struct OpcodeClassMap<arch::OpClassBlockScaledTensorOp> {
|
||||
static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp;
|
||||
};
|
||||
|
||||
|
||||
template <> struct OpcodeClassMap<arch::OpClassWmmaTensorOp> {
|
||||
static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp;
|
||||
};
|
||||
|
||||
@ -48,6 +48,45 @@ void OperationTable::append(Manifest const &manifest) {
|
||||
// Insert operations into appropriate data structure
|
||||
for (auto const & operation : manifest) {
|
||||
OperationDescription const &desc = operation->description();
|
||||
|
||||
if (desc.kind == OperationKind::kBlockScaledGemm) {
|
||||
BlockScaledGemmDescription const &gemm_desc = static_cast<BlockScaledGemmDescription const &>(desc);
|
||||
|
||||
BlockScaledGemmFunctionalKey functional_key(
|
||||
gemm_desc.provider,
|
||||
gemm_desc.gemm_kind,
|
||||
gemm_desc.kind,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
gemm_desc.A.element,
|
||||
gemm_desc.A.layout,
|
||||
gemm_desc.SFA.element,
|
||||
gemm_desc.B.element,
|
||||
gemm_desc.B.layout,
|
||||
gemm_desc.SFB.element,
|
||||
gemm_desc.C.element,
|
||||
gemm_desc.C.layout,
|
||||
gemm_desc.D.element,
|
||||
gemm_desc.D.layout,
|
||||
gemm_desc.SFD.element,
|
||||
gemm_desc.SFD.layout,
|
||||
gemm_desc.SFVecSize
|
||||
, gemm_desc.EpilogueSFVecSize
|
||||
);
|
||||
|
||||
Operation const *op = operation.get();
|
||||
|
||||
int cc = gemm_desc.tile_description.minimum_compute_capability;
|
||||
|
||||
int alignment = std::max(std::max(
|
||||
gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment);
|
||||
|
||||
GemmPreferenceKey preference_key(cc, alignment);
|
||||
|
||||
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
|
||||
}
|
||||
|
||||
|
||||
// insert all gemm operation into operation table
|
||||
if (desc.kind == OperationKind::kGemm) {
|
||||
GemmDescription const &gemm_desc = static_cast<GemmDescription const &>(desc);
|
||||
|
||||
128
tools/library/src/reference/block_scaled_gemm_fp4a_vs16.cu
Normal file
128
tools/library/src/reference/block_scaled_gemm_fp4a_vs16.cu
Normal file
@ -0,0 +1,128 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "block_scaled_gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest) {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// SFVectorSize = 16 with MxF4NvF4 instructions
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// (float_e2m1_t * float_ue4m3_t) * (float_e2m1_t * float_ue4m3_t)
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 16 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
|
||||
16 /*EpilogueSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
|
||||
32 /*EpilogueSFVecSize*/
|
||||
>(manifest);
|
||||
// (float_e2m1_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 16 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
|
||||
16 /*EpilogueSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
|
||||
16 /*EpilogueSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
|
||||
32 /*EpilogueSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm_tn<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/,
|
||||
32 /*EpilogueSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
130
tools/library/src/reference/block_scaled_gemm_fp4a_vs32.cu
Normal file
130
tools/library/src/reference/block_scaled_gemm_fp4a_vs32.cu
Normal file
@ -0,0 +1,130 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "block_scaled_gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest) {
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// SFVectorSize = 32 with MxF4 instructions
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// (float_e2m1_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
// With SF generation reference
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
|
||||
16 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
|
||||
16 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
|
||||
32 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/,
|
||||
32 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
|
||||
32 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/,
|
||||
32 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/,
|
||||
32 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
make_block_scaled_gemm<
|
||||
float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/,
|
||||
half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/,
|
||||
32 /*EpiSFVecSize*/
|
||||
>(manifest);
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
354
tools/library/src/reference/block_scaled_gemm_mixed8bitsa.cu
Normal file
354
tools/library/src/reference/block_scaled_gemm_mixed8bitsa.cu
Normal file
@ -0,0 +1,354 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "block_scaled_gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest) {

  //
  // SFVectorSize = 32 with MxF8F6F4 instructions.
  //
  // Every instantiation below shares: SFA/SFB = float_ue8m0_t, Compute = float,
  // Accum = float, SFVecSize = 32. The varying parameters are A, B, C, SFD, D
  // (and, where written explicitly, EpilogueSFVecSize).
  // Template argument order:
  //   <A, SFA, B, SFB, C, Compute, SFD, Accum, D, SFVecSize[, EpilogueSFVecSize]>

  // (float_e2m3_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);

  // (float_e4m3_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32>(manifest);

  // (float_e2m3_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);

  // (float_e2m1_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32>(manifest);

  // (float_e4m3_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);

  // (float_e4m3_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32>(manifest);

  // Additional SFD-generating variants.
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e4m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e4m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);

  // (float_e3m2_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e3m2_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e3m2_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e3m2_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e3m2_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);
  make_block_scaled_gemm<float_e3m2_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);

  // (float_e2m1_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);
  make_block_scaled_gemm<float_e2m1_t, float_ue8m0_t, float_e2m3_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);

  // (float_e2m3_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t)
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         void, float, void, float, float, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         void, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, void, float, float_e5m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, void, float, float_e3m2_t, 32>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e3m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);
  make_block_scaled_gemm<float_e2m3_t, float_ue8m0_t, float_e2m1_t, float_ue8m0_t,
                         half_t, float, float_ue8m0_t, float, float_e5m2_t, 32,
                         32 /*EpilogueSFVecSize*/>(manifest);
}
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,459 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "library_internal.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
Provider Provider_,
|
||||
typename ElementA_,
|
||||
typename LayoutA_,
|
||||
typename ElementSFA_,
|
||||
typename ElementB_,
|
||||
typename LayoutB_,
|
||||
typename ElementSFB_,
|
||||
typename ElementC_,
|
||||
typename LayoutC_,
|
||||
typename ElementCompute_,
|
||||
typename ElementAccumulator_ = ElementCompute_,
|
||||
typename ElementD_ = ElementC_,
|
||||
typename ElementSFD_ = void,
|
||||
typename LayoutSFD_ = LayoutC_,
|
||||
int SFVecSize_ = 32,
|
||||
int EpilogueSFVecSize_ = 0,
|
||||
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
|
||||
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
||||
>
|
||||
class BlockScaledGemmReferenceOperation : public Operation {
|
||||
public:
|
||||
static Provider const kProvider = Provider_;
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = LayoutA_;
|
||||
using ElementSFA = ElementSFA_;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = LayoutB_;
|
||||
using ElementSFB = ElementSFB_;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementSFD = ElementSFD_;
|
||||
using LayoutSFD = LayoutSFD_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ConvertOp = ConvertOp_;
|
||||
using InnerProductOp = InnerProductOp_;
|
||||
constexpr static int SFVecSize = SFVecSize_;
|
||||
constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_;
|
||||
|
||||
protected:
|
||||
|
||||
/// Storage for the name string
|
||||
std::string name_;
|
||||
|
||||
///
|
||||
BlockScaledGemmDescription description_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
BlockScaledGemmReferenceOperation() {
|
||||
|
||||
// Basic information
|
||||
description_.provider = kProvider;
|
||||
description_.kind = OperationKind::kBlockScaledGemm;
|
||||
description_.gemm_kind = GemmKind::kUniversal;
|
||||
|
||||
// Tensor description
|
||||
description_.A = make_TensorDescription<ElementA, LayoutA>();
|
||||
description_.SFA = make_TensorDescription<ElementSFA, LayoutA>();
|
||||
description_.B = make_TensorDescription<ElementB, LayoutB>();
|
||||
description_.SFB = make_TensorDescription<ElementSFB, LayoutB>();
|
||||
description_.C = make_TensorDescription<ElementC, LayoutC>();
|
||||
description_.D = make_TensorDescription<ElementD, LayoutC>();
|
||||
description_.SFD = make_TensorDescription<ElementSFD, LayoutSFD>();
|
||||
|
||||
// Epilogue compute and accumulator type description
|
||||
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.element_accumulator =
|
||||
NumericTypeMap<ElementAccumulator>::kId;
|
||||
|
||||
// Compute capability for gemm reference
|
||||
description_.tile_description.minimum_compute_capability =
|
||||
(kProvider == Provider::kReferenceDevice ? 50 : 0);
|
||||
|
||||
description_.tile_description.maximum_compute_capability = 1024;
|
||||
|
||||
description_.SFVecSize = SFVecSize;
|
||||
description_.EpilogueSFVecSize = EpilogueSFVecSize;
|
||||
|
||||
// Procedural name
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "gemm"
|
||||
<< "_reference_" << to_string(description_.provider)
|
||||
<< "_" << to_string(description_.A.element) << to_string(description_.A.layout)
|
||||
<< "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout)
|
||||
<< "_" << to_string(description_.B.element) << to_string(description_.B.layout)
|
||||
<< "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout)
|
||||
<< "_" << to_string(description_.C.element) << to_string(description_.C.layout)
|
||||
<< "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout)
|
||||
<< "_" << to_string(description_.tile_description.math_instruction.element_accumulator);
|
||||
|
||||
name_ = ss.str();
|
||||
|
||||
description_.name = name_.c_str();
|
||||
|
||||
// Epilogue compute and accumulator type description
|
||||
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.element_accumulator =
|
||||
NumericTypeMap<ElementAccumulator>::kId;
|
||||
}
|
||||
|
||||
/// Returns the description of the GEMM operation
|
||||
virtual OperationDescription const & description() const {
|
||||
return description_;
|
||||
}
|
||||
|
||||
virtual Status can_implement(
|
||||
void const *configuration,
|
||||
void const *arguments) const {
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
virtual uint64_t get_host_workspace_size(
|
||||
void const *configuration) const {
|
||||
|
||||
return sizeof(GemmUniversalConfiguration);
|
||||
}
|
||||
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration,
|
||||
void const *arguments = nullptr) const {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
virtual Status initialize(
|
||||
void const *configuration,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
virtual Status run(
|
||||
void const *arguments,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
using namespace cute;
|
||||
|
||||
BlockScaledGemmArguments const &args = *static_cast<BlockScaledGemmArguments const *>(arguments);
|
||||
|
||||
// Construct cute::Tensor A/B/C
|
||||
|
||||
int M = args.problem_size.m();
|
||||
int N = args.problem_size.n();
|
||||
int K = args.problem_size.k();
|
||||
int L = args.batch_count;
|
||||
|
||||
auto problem_shape_MNKL = cute::make_shape(M, N, K, L);
|
||||
|
||||
auto alpha = *(static_cast<ElementCompute const*>(args.alpha));
|
||||
auto beta = *(static_cast<ElementCompute const*>(args.beta));
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
|
||||
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
|
||||
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutC>;
|
||||
|
||||
auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
using Sm100BlockScaledConfig = cutlass::detail::Sm100BlockScaledConfig<SFVecSize>;
|
||||
auto A = cute::make_tensor(detail::make_iterator(static_cast<ElementA const*>(args.A)),
|
||||
cute::make_layout(cute::make_shape(M, K, L), stride_a));
|
||||
auto SfA = make_tensor(static_cast<ElementSFA const*>(args.SFA), Sm100BlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL));
|
||||
|
||||
auto B = cute::make_tensor(detail::make_iterator(static_cast<ElementB const*>(args.B)),
|
||||
cute::make_layout(cute::make_shape(N, K, L), stride_b));
|
||||
auto SfB = make_tensor(static_cast<ElementSFB const*>(args.SFB), Sm100BlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL));
|
||||
|
||||
auto C = [&]() {
|
||||
if constexpr (not is_same_v<ElementC, void>) {
|
||||
return cute::make_tensor(detail::make_iterator(static_cast<ElementC const*>(args.C)),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_c));
|
||||
}
|
||||
else {
|
||||
return cute::make_tensor(detail::make_iterator(static_cast<ElementD const*>(nullptr)),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_c));
|
||||
}
|
||||
}();
|
||||
|
||||
auto D = cute::make_tensor(detail::make_iterator(static_cast<ElementD *>(args.D)),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_d));
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
|
||||
decltype(A), decltype(SfA),
|
||||
decltype(B), decltype(SfB)>
|
||||
mainloop_params{A, SfA, B, SfB};
|
||||
|
||||
if constexpr (not is_same_v<ElementSFD, void>) {
|
||||
|
||||
using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig<
|
||||
EpilogueSFVecSize
|
||||
>;
|
||||
|
||||
auto SfD = cute::make_tensor(detail::make_iterator(static_cast<ElementSFD*>(args.SFD)), Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL));
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, ElementAccumulator, ElementCompute,
|
||||
decltype(C), decltype(D), decltype(SfD), Int<EpilogueSFVecSize>, cutlass::reference::host::SfStrategy::SfDGen>
|
||||
epilogue_params{alpha, beta, C, D, SfD, *(static_cast<ElementCompute const*>(args.norm_constant))};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
}
|
||||
else {
|
||||
// W/O SF generation
|
||||
auto SfD = cute::make_tensor(static_cast<ElementSFA *>(nullptr),
|
||||
cute::make_layout(cute::make_shape(M, N, L))); // not used.
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, ElementAccumulator, ElementCompute,
|
||||
decltype(C), decltype(D), decltype(SfD)>
|
||||
epilogue_params{alpha, beta, C, D, SfD};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA_,
|
||||
typename ElementSFA_,
|
||||
typename ElementB_,
|
||||
typename ElementSFB_,
|
||||
typename ElementC_,
|
||||
typename ElementCompute_,
|
||||
typename ElementSFD_ = void,
|
||||
typename ElementAccumulator_ = ElementCompute_,
|
||||
typename ElementD_ = ElementC_,
|
||||
int SFVecSize = 32,
|
||||
int EpilogueSFVecSize = SFVecSize,
|
||||
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
|
||||
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
||||
>
|
||||
void make_block_scaled_gemm_tn(Manifest &manifest) {
|
||||
#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE)
|
||||
manifest.append(new BlockScaledGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ElementSFD_,
|
||||
cutlass::layout::RowMajor,
|
||||
SFVecSize,
|
||||
EpilogueSFVecSize,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>);
|
||||
#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Registers host-side block-scaled GEMM reference operations into the manifest.
/// Appends four BlockScaledGemmReferenceOperation instances covering the
/// layout combinations (A row-major / B column-major) and (A column-major /
/// B row-major), each with both row-major and column-major C. D and the
/// scale-factor output SFD are always row-major in the instantiations below.
template <
  typename ElementA_,                                                   // A operand element type
  typename ElementSFA_,                                                 // scale-factor type for A
  typename ElementB_,                                                   // B operand element type
  typename ElementSFB_,                                                 // scale-factor type for B
  typename ElementC_,                                                   // C (source) element type
  typename ElementCompute_,                                             // epilogue compute type
  typename ElementSFD_ = void,                                          // scale-factor type for D (void: none)
  typename ElementAccumulator_ = ElementCompute_,                       // accumulator type
  typename ElementD_ = ElementC_,                                       // D (output) element type
  int SFVecSize = 32,                                                   // scale-factor vector size for inputs
  int EpilogueSFVecSize = SFVecSize,                                    // scale-factor vector size for epilogue
  typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,   // compute -> D conversion
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>          // per-element MMA operation
>
void make_block_scaled_gemm(Manifest &manifest) {
  ///
  /// A is Row , B is Col
  ///
  // C row-major
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::RowMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::ColumnMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::RowMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
  // C column-major
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::RowMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::ColumnMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::ColumnMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
  ///
  /// A is Col , B is Row
  ///
  // C row-major
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::ColumnMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::RowMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::RowMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
  // C column-major
  manifest.append(new BlockScaledGemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_,
    cutlass::layout::ColumnMajor,
    ElementSFA_,
    ElementB_,
    cutlass::layout::RowMajor,
    ElementSFB_,
    ElementC_,
    cutlass::layout::ColumnMajor,
    ElementCompute_,
    ElementAccumulator_,
    ElementD_,
    ElementSFD_,
    cutlass::layout::RowMajor,
    SFVecSize,
    EpilogueSFVecSize,
    ConvertOp_,
    InnerProductOp_
  >);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
109
tools/library/src/reference/gemm_f4_f4_f32.cu
Normal file
109
tools/library/src/reference/gemm_f4_f4_f32.cu
Normal file
@ -0,0 +1,109 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP4 (e2m1) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A/B : float_e2m1_t (not support float_e0m2_t to reduce ref kernel compile time)
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e2m1_e2m1_f32_f16_e4m3
|
||||
// 2. e2m1_e2m1_f32_f16_e5m2
|
||||
// 3. e2m1_e2m1_f32_f16_f16
|
||||
// 4. e2m1_e2m1_f32_f32_f32
|
||||
|
||||
/// Registers host reference GEMM operations with float_e2m1_t A and B operands
/// and f32 accumulation into the manifest, covering four C/D type combinations.
/// Template arguments of make_gemm_real_canonical_layouts, in order:
/// <ElementA, ElementB, ElementC, ElementScalar, ElementAccumulator, ElementD>
void initialize_gemm_reference_operations_f4_f4_f32(Manifest &manifest) {

  // e2m1_e2m1_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e2m1_t, half_t, float, float, float_e4m3_t>(manifest);

  // e2m1_e2m1_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e2m1_t, half_t, float, float, float_e5m2_t>(manifest);

  // e2m1_e2m1_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e2m1_t, half_t, float, float, half_t>(manifest);

  // e2m1_e2m1_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e2m1_t, float, float, float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
110
tools/library/src/reference/gemm_f4_f6_f32.cu
Normal file
110
tools/library/src/reference/gemm_f4_f6_f32.cu
Normal file
@ -0,0 +1,110 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP4 x FP6 (e2m1 x e3m2) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A: float_e2m1_t
|
||||
// B: float_e3m2_t
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e2m1_e3m2_f32_f16_e4m3
|
||||
// 2. e2m1_e3m2_f32_f16_e5m2
|
||||
// 3. e2m1_e3m2_f32_f16_f16
|
||||
// 4. e2m1_e3m2_f32_f32_f32
|
||||
|
||||
/// Registers host reference GEMM operations with float_e2m1_t A and
/// float_e3m2_t B operands and f32 accumulation into the manifest, covering
/// four C/D type combinations.
/// Template arguments of make_gemm_real_canonical_layouts, in order:
/// <ElementA, ElementB, ElementC, ElementScalar, ElementAccumulator, ElementD>
void initialize_gemm_reference_operations_f4_f6_f32(Manifest &manifest) {

  // e2m1_e3m2_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e3m2_t, half_t, float, float, float_e4m3_t>(manifest);

  // e2m1_e3m2_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e3m2_t, half_t, float, float, float_e5m2_t>(manifest);

  // e2m1_e3m2_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e3m2_t, half_t, float, float, half_t>(manifest);

  // e2m1_e3m2_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e3m2_t, float, float, float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
110
tools/library/src/reference/gemm_f4_f8_f32.cu
Normal file
110
tools/library/src/reference/gemm_f4_f8_f32.cu
Normal file
@ -0,0 +1,110 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP4 x FP8 (e2m1 x e4m3) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A: float_e2m1_t
|
||||
// B: float_e4m3_t
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e2m1_e4m3_f32_f16_e4m3
|
||||
// 2. e2m1_e4m3_f32_f16_e5m2
|
||||
// 3. e2m1_e4m3_f32_f16_f16
|
||||
// 4. e2m1_e4m3_f32_f32_f32
|
||||
|
||||
/// Registers host reference GEMM operations with float_e2m1_t A and
/// float_e4m3_t B operands and f32 accumulation into the manifest, covering
/// four C/D type combinations.
/// Template arguments of make_gemm_real_canonical_layouts, in order:
/// <ElementA, ElementB, ElementC, ElementScalar, ElementAccumulator, ElementD>
void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest) {

  // e2m1_e4m3_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e4m3_t, half_t, float, float, float_e4m3_t>(manifest);

  // e2m1_e4m3_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e4m3_t, half_t, float, float, float_e5m2_t>(manifest);

  // e2m1_e4m3_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e4m3_t, half_t, float, float, half_t>(manifest);

  // e2m1_e4m3_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e2m1_t, float_e4m3_t, float, float, float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
110
tools/library/src/reference/gemm_f6_f4_f32.cu
Normal file
110
tools/library/src/reference/gemm_f6_f4_f32.cu
Normal file
@ -0,0 +1,110 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP6 x FP4 (e3m2 x e2m1) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A: float_e3m2_t
|
||||
// B: float_e2m1_t
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e3m2_e2m1_f32_f16_e4m3
|
||||
// 2. e3m2_e2m1_f32_f16_e5m2
|
||||
// 3. e3m2_e2m1_f32_f16_f16
|
||||
// 4. e3m2_e2m1_f32_f32_f32
|
||||
|
||||
/// Registers host reference GEMM operations with float_e3m2_t A and
/// float_e2m1_t B operands and f32 accumulation into the manifest, covering
/// four C/D type combinations.
/// Template arguments of make_gemm_real_canonical_layouts, in order:
/// <ElementA, ElementB, ElementC, ElementScalar, ElementAccumulator, ElementD>
void initialize_gemm_reference_operations_f6_f4_f32(Manifest &manifest) {

  // e3m2_e2m1_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e2m1_t, half_t, float, float, float_e4m3_t>(manifest);

  // e3m2_e2m1_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e2m1_t, half_t, float, float, float_e5m2_t>(manifest);

  // e3m2_e2m1_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e2m1_t, half_t, float, float, half_t>(manifest);

  // e3m2_e2m1_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e2m1_t, float, float, float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
109
tools/library/src/reference/gemm_f6_f6_f32.cu
Normal file
109
tools/library/src/reference/gemm_f6_f6_f32.cu
Normal file
@ -0,0 +1,109 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP6 (e3m2) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A/B : float_e3m2_t (not support float_e2m3_t to reduce ref kernel compile time)
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e3m2_e3m2_f32_f16_e4m3
|
||||
// 2. e3m2_e3m2_f32_f16_e5m2
|
||||
// 3. e3m2_e3m2_f32_f16_f16
|
||||
// 4. e3m2_e3m2_f32_f32_f32
|
||||
|
||||
/// Registers host reference GEMM operations with float_e3m2_t A and B operands
/// and f32 accumulation into the manifest, covering four C/D type combinations.
/// Template arguments of make_gemm_real_canonical_layouts, in order:
/// <ElementA, ElementB, ElementC, ElementScalar, ElementAccumulator, ElementD>
void initialize_gemm_reference_operations_f6_f6_f32(Manifest &manifest) {

  // e3m2_e3m2_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e3m2_t, half_t, float, float, float_e4m3_t>(manifest);

  // e3m2_e3m2_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e3m2_t, half_t, float, float, float_e5m2_t>(manifest);

  // e3m2_e3m2_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e3m2_t, half_t, float, float, half_t>(manifest);

  // e3m2_e3m2_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e3m2_t, float, float, float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
110
tools/library/src/reference/gemm_f6_f8_f32.cu
Normal file
110
tools/library/src/reference/gemm_f6_f8_f32.cu
Normal file
@ -0,0 +1,110 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP6 x FP8 (e3m2 x e4m3) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A: float_e3m2_t
|
||||
// B: float_e4m3_t
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e3m2_e4m3_f32_f16_e4m3
|
||||
// 2. e3m2_e4m3_f32_f16_e5m2
|
||||
// 3. e3m2_e4m3_f32_f16_f16
|
||||
// 4. e3m2_e4m3_f32_f32_f32
|
||||
|
||||
/// Registers host reference GEMM operations with float_e3m2_t A and
/// float_e4m3_t B operands and f32 accumulation into the manifest, covering
/// four C/D type combinations.
/// Template arguments of make_gemm_real_canonical_layouts, in order:
/// <ElementA, ElementB, ElementC, ElementScalar, ElementAccumulator, ElementD>
void initialize_gemm_reference_operations_f6_f8_f32(Manifest &manifest) {

  // e3m2_e4m3_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e4m3_t, half_t, float, float, float_e4m3_t>(manifest);

  // e3m2_e4m3_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e4m3_t, half_t, float, float, float_e5m2_t>(manifest);

  // e3m2_e4m3_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e4m3_t, half_t, float, float, half_t>(manifest);

  // e3m2_e4m3_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e3m2_t, float_e4m3_t, float, float, float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
110
tools/library/src/reference/gemm_f8_f4_f32.cu
Normal file
110
tools/library/src/reference/gemm_f8_f4_f32.cu
Normal file
@ -0,0 +1,110 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP8 x FP4 (e4m3 x e2m1) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A: float_e4m3_t
|
||||
// B: float_e2m1_t
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e4m3_e2m1_f32_f16_e4m3
|
||||
// 2. e4m3_e2m1_f32_f16_e5m2
|
||||
// 3. e4m3_e2m1_f32_f16_f16
|
||||
// 4. e4m3_e2m1_f32_f32_f32
|
||||
|
||||
/// Registers host reference GEMM operations with float_e4m3_t A and
/// float_e2m1_t B operands and f32 accumulation into the manifest, covering
/// four C/D type combinations.
/// Template arguments of make_gemm_real_canonical_layouts, in order:
/// <ElementA, ElementB, ElementC, ElementScalar, ElementAccumulator, ElementD>
void initialize_gemm_reference_operations_f8_f4_f32(Manifest &manifest) {

  // e4m3_e2m1_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e4m3_t, float_e2m1_t, half_t, float, float, float_e4m3_t>(manifest);

  // e4m3_e2m1_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e4m3_t, float_e2m1_t, half_t, float, float, float_e5m2_t>(manifest);

  // e4m3_e2m1_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e4m3_t, float_e2m1_t, half_t, float, float, half_t>(manifest);

  // e4m3_e2m1_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e4m3_t, float_e2m1_t, float, float, float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
110
tools/library/src/reference/gemm_f8_f6_f32.cu
Normal file
110
tools/library/src/reference/gemm_f8_f6_f32.cu
Normal file
@ -0,0 +1,110 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Instantiates GEMM reference implementations for FP8 (e4m3) x FP6 (e3m2) operands with FP32 accumulation.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A: float_e4m3_t
|
||||
// B: float_e3m2_t
|
||||
// Acc: f32
|
||||
// C/D : some variance
|
||||
|
||||
// 1. e4m3_e3m2_f32_f16_e4m3
|
||||
// 2. e4m3_e3m2_f32_f16_e5m2
|
||||
// 3. e4m3_e3m2_f32_f16_f16
|
||||
// 4. e4m3_e3m2_f32_f32_f32
|
||||
|
||||
/// Registers reference GEMM operations with A = float_e4m3_t (FP8),
/// B = float_e3m2_t (FP6) and FP32 accumulation in the given manifest.
///
/// Four C/D element combinations are instantiated, in this order:
///   1. e4m3_e3m2_f32_f16_e4m3  (C = half_t, D = float_e4m3_t)
///   2. e4m3_e3m2_f32_f16_e5m2  (C = half_t, D = float_e5m2_t)
///   3. e4m3_e3m2_f32_f16_f16   (C = half_t, D = half_t)
///   4. e4m3_e3m2_f32_f32_f32   (C = float,  D = float)
void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest) {

  // 1. e4m3_e3m2_f32_f16_e4m3
  make_gemm_real_canonical_layouts<
    float_e4m3_t,   // ElementA
    float_e3m2_t,   // ElementB
    half_t,         // ElementC
    float,          // ElementScalar
    float,          // ElementAccumulator
    float_e4m3_t    // ElementD
  >(manifest);

  // 2. e4m3_e3m2_f32_f16_e5m2
  make_gemm_real_canonical_layouts<
    float_e4m3_t,   // ElementA
    float_e3m2_t,   // ElementB
    half_t,         // ElementC
    float,          // ElementScalar
    float,          // ElementAccumulator
    float_e5m2_t    // ElementD
  >(manifest);

  // 3. e4m3_e3m2_f32_f16_f16
  make_gemm_real_canonical_layouts<
    float_e4m3_t,   // ElementA
    float_e3m2_t,   // ElementB
    half_t,         // ElementC
    float,          // ElementScalar
    float,          // ElementAccumulator
    half_t          // ElementD
  >(manifest);

  // 4. e4m3_e3m2_f32_f32_f32
  make_gemm_real_canonical_layouts<
    float_e4m3_t,   // ElementA
    float_e3m2_t,   // ElementB
    float,          // ElementC
    float,          // ElementScalar
    float,          // ElementAccumulator
    float           // ElementD
  >(manifest);

}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -87,6 +87,17 @@ void initialize_gemm_reference_operations_u8_u8_s32(Manifest &manifest) {
|
||||
NumericConverterClamp<int8_t, float> // From Scalar to D
|
||||
>(manifest);
|
||||
|
||||
// 4.
|
||||
make_gemm_real_canonical_layouts<
|
||||
uint8_t, // ElementA
|
||||
uint8_t, // ElementB
|
||||
int8_t, // ElementC
|
||||
float, // ElementScalar / ElementCompute
|
||||
int32_t, // ElementAccumulator
|
||||
uint8_t, // ElementD
|
||||
NumericConverterClamp<uint8_t, float> // From Scalar to D
|
||||
>(manifest);
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -52,6 +52,19 @@ void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest);
|
||||
|
||||
void initialize_gemm_reference_operations_f4_f4_f32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_f4_f6_f32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_f6_f4_f32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_f6_f6_f32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_f6_f8_f32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_f8_f4_f32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest);
|
||||
|
||||
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
|
||||
@ -89,6 +102,19 @@ void initialize_reference_operations(Manifest &manifest) {
|
||||
initialize_gemm_reference_operations_fp_mixed_input(manifest);
|
||||
initialize_gemm_reference_operations_int_mixed_input(manifest);
|
||||
|
||||
|
||||
initialize_gemm_reference_operations_f4_f4_f32(manifest);
|
||||
initialize_gemm_reference_operations_f4_f6_f32(manifest);
|
||||
initialize_gemm_reference_operations_f4_f8_f32(manifest);
|
||||
initialize_gemm_reference_operations_f6_f4_f32(manifest);
|
||||
initialize_gemm_reference_operations_f6_f6_f32(manifest);
|
||||
initialize_gemm_reference_operations_f6_f8_f32(manifest);
|
||||
initialize_gemm_reference_operations_f8_f4_f32(manifest);
|
||||
initialize_gemm_reference_operations_f8_f6_f32(manifest);
|
||||
initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest);
|
||||
initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest);
|
||||
initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest);
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -211,6 +211,10 @@ protected:
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
|
||||
operator_args.scheduler.splits = arguments->split_k_slices;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
|
||||
@ -334,6 +334,7 @@ static struct {
|
||||
OperationKind_enumerants[] = {
|
||||
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
|
||||
{"gemm", "Gemm", OperationKind::kGemm},
|
||||
{"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm},
|
||||
{"rank_k", "RankK", OperationKind::kRankK},
|
||||
{"rank_2k", "Rank2K", OperationKind::kRank2K},
|
||||
{"trmm", "Trmm", OperationKind::kTrmm},
|
||||
@ -422,6 +423,53 @@ Status from_string<Status>(std::string const &str) {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
static struct {
|
||||
char const *text;
|
||||
char const *pretty;
|
||||
RuntimeDatatype enumerant;
|
||||
}
|
||||
RuntimeDatatype_enumerants[] = {
|
||||
{"e4m3", "<e4m3>", RuntimeDatatype::kE4M3},
|
||||
{"e5m2", "<e5m2>", RuntimeDatatype::kE5M2},
|
||||
{"e3m2", "<e3m2>", RuntimeDatatype::kE3M2},
|
||||
{"e2m3", "<e2m3>", RuntimeDatatype::kE2M3},
|
||||
{"e2m1", "<e2m1>", RuntimeDatatype::kE2M1}
|
||||
};
|
||||
|
||||
/// Converts a RuntimeDatatype enumerant to a string
|
||||
char const *to_string(RuntimeDatatype type, bool pretty) {
|
||||
|
||||
for (auto const & possible : RuntimeDatatype_enumerants) {
|
||||
if (type == possible.enumerant) {
|
||||
if (pretty) {
|
||||
return possible.pretty;
|
||||
}
|
||||
else {
|
||||
return possible.text;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pretty ? "Invalid" : "invalid";
|
||||
}
|
||||
|
||||
|
||||
/// Converts a RuntimeDatatype enumerant from a string
|
||||
template <>
|
||||
RuntimeDatatype from_string<RuntimeDatatype>(std::string const &str) {
|
||||
|
||||
for (auto const & possible : RuntimeDatatype_enumerants) {
|
||||
if ((str.compare(possible.text) == 0) ||
|
||||
(str.compare(possible.pretty) == 0)) {
|
||||
return possible.enumerant;
|
||||
}
|
||||
}
|
||||
|
||||
return RuntimeDatatype::kInvalid;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static struct {
|
||||
@ -447,6 +495,16 @@ NumericTypeID_enumerants[] = {
|
||||
{"s64", "S64", NumericTypeID::kS64},
|
||||
{"fe4m3", "FE4M3", NumericTypeID::kFE4M3},
|
||||
{"fe5m2", "FE5M2", NumericTypeID::kFE5M2},
|
||||
|
||||
{"f8", "F8", NumericTypeID::kF8},
|
||||
{"f6", "F6", NumericTypeID::kF6},
|
||||
{"f4", "F4", NumericTypeID::kF4},
|
||||
{"fe2m3", "FE2M3", NumericTypeID::kFE2M3},
|
||||
{"fe3m2", "FE3M2", NumericTypeID::kFE3M2},
|
||||
{"fe2m1", "FE2M1", NumericTypeID::kFE2M1},
|
||||
{"fue8m0", "FUE8M0", NumericTypeID::kFUE8M0},
|
||||
{"fue4m3", "FUE4M3", NumericTypeID::kFUE4M3},
|
||||
|
||||
{"f16", "F16", NumericTypeID::kF16},
|
||||
{"bf16", "BF16", NumericTypeID::kBF16},
|
||||
{"f32", "F32", NumericTypeID::kF32},
|
||||
@ -510,6 +568,16 @@ int sizeof_bits(NumericTypeID type) {
|
||||
switch (type) {
|
||||
case NumericTypeID::kFE4M3: return 8;
|
||||
case NumericTypeID::kFE5M2: return 8;
|
||||
|
||||
case NumericTypeID::kF8: return 8;
|
||||
case NumericTypeID::kF6: return 6;
|
||||
case NumericTypeID::kF4: return 4;
|
||||
case NumericTypeID::kFE2M3: return 6;
|
||||
case NumericTypeID::kFE3M2: return 6;
|
||||
case NumericTypeID::kFE2M1: return 4;
|
||||
case NumericTypeID::kFUE8M0: return 8;
|
||||
case NumericTypeID::kFUE4M3: return 8;
|
||||
|
||||
case NumericTypeID::kF16: return 16;
|
||||
case NumericTypeID::kBF16: return 16;
|
||||
case NumericTypeID::kTF32: return 32;
|
||||
@ -589,6 +657,16 @@ bool is_signed_type(NumericTypeID type) {
|
||||
switch (type) {
|
||||
case NumericTypeID::kFE4M3: return true;
|
||||
case NumericTypeID::kFE5M2: return true;
|
||||
|
||||
case NumericTypeID::kF8: return true;
|
||||
case NumericTypeID::kF6: return true;
|
||||
case NumericTypeID::kF4: return true;
|
||||
case NumericTypeID::kFE2M3: return true;
|
||||
case NumericTypeID::kFE3M2: return true;
|
||||
case NumericTypeID::kFE2M1: return true;
|
||||
case NumericTypeID::kFUE8M0: return false;
|
||||
case NumericTypeID::kFUE4M3: return false;
|
||||
|
||||
case NumericTypeID::kF16: return true;
|
||||
case NumericTypeID::kBF16: return true;
|
||||
case NumericTypeID::kTF32: return true;
|
||||
@ -620,6 +698,16 @@ bool is_float_type(NumericTypeID type) {
|
||||
switch (type) {
|
||||
case NumericTypeID::kFE4M3: return true;
|
||||
case NumericTypeID::kFE5M2: return true;
|
||||
|
||||
case NumericTypeID::kF8: return true;
|
||||
case NumericTypeID::kF6: return true;
|
||||
case NumericTypeID::kF4: return true;
|
||||
case NumericTypeID::kFE2M3: return true;
|
||||
case NumericTypeID::kFE3M2: return true;
|
||||
case NumericTypeID::kFE2M1: return true;
|
||||
case NumericTypeID::kFUE8M0: return true;
|
||||
case NumericTypeID::kFUE4M3: return true;
|
||||
|
||||
case NumericTypeID::kF16: return true;
|
||||
case NumericTypeID::kBF16: return true;
|
||||
case NumericTypeID::kTF32: return true;
|
||||
@ -1168,6 +1256,43 @@ bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string c
|
||||
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(tmp);
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kFE2M3:
|
||||
{
|
||||
float tmp;
|
||||
ss >> tmp;
|
||||
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(tmp);
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE3M2:
|
||||
{
|
||||
float tmp;
|
||||
ss >> tmp;
|
||||
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(tmp);
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE2M1:
|
||||
{
|
||||
float tmp;
|
||||
ss >> tmp;
|
||||
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(tmp);
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE8M0:
|
||||
{
|
||||
float tmp;
|
||||
ss >> tmp;
|
||||
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(tmp);
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE4M3:
|
||||
{
|
||||
float tmp;
|
||||
ss >> tmp;
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(tmp);
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
float tmp;
|
||||
@ -1317,6 +1442,38 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
|
||||
ss << tmp;
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kFE2M3:
|
||||
{
|
||||
float tmp = *reinterpret_cast<float_e2m3_t *>(bytes.data());
|
||||
ss << tmp;
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE3M2:
|
||||
{
|
||||
float tmp = *reinterpret_cast<float_e3m2_t *>(bytes.data());
|
||||
ss << tmp;
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE2M1:
|
||||
{
|
||||
float tmp = *reinterpret_cast<float_e2m1_t *>(bytes.data());
|
||||
ss << tmp;
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE8M0:
|
||||
{
|
||||
float tmp = *reinterpret_cast<float_ue8m0_t *>(bytes.data());
|
||||
ss << tmp;
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE4M3:
|
||||
{
|
||||
float tmp = *reinterpret_cast<float_ue4m3_t *>(bytes.data());
|
||||
ss << tmp;
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
float tmp = *reinterpret_cast<half_t *>(bytes.data());
|
||||
@ -1469,6 +1626,33 @@ bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t sr
|
||||
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kFE2M3:
|
||||
{
|
||||
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE3M2:
|
||||
{
|
||||
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE2M1:
|
||||
{
|
||||
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE8M0:
|
||||
{
|
||||
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE4M3:
|
||||
{
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
||||
@ -1579,6 +1763,33 @@ bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t
|
||||
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kFE2M3:
|
||||
{
|
||||
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE3M2:
|
||||
{
|
||||
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE2M1:
|
||||
{
|
||||
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE8M0:
|
||||
{
|
||||
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE4M3:
|
||||
{
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
||||
@ -1690,6 +1901,33 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
|
||||
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kFE2M3:
|
||||
{
|
||||
*reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE3M2:
|
||||
{
|
||||
*reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFE2M1:
|
||||
{
|
||||
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE8M0:
|
||||
{
|
||||
*reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
|
||||
}
|
||||
break;
|
||||
case NumericTypeID::kFUE4M3:
|
||||
{
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
||||
@ -1751,6 +1989,35 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/// Maps a runtime-selectable datatype enumerant (RuntimeDatatype) to the
/// corresponding static NumericTypeID used throughout the library.
///
/// For an unrecognized input, the debug build aborts via assert; in release
/// builds the value-initialized NumericTypeID (i.e. NumericTypeID{}) is
/// returned unchanged.
NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type) {
  NumericTypeID element{};
  switch (type) {
    case RuntimeDatatype::kE4M3:
      element = NumericTypeID::kFE4M3;
      break;
    case RuntimeDatatype::kE5M2:
      element = NumericTypeID::kFE5M2;
      break;

    case RuntimeDatatype::kE2M3:
      element = NumericTypeID::kFE2M3;
      break;
    case RuntimeDatatype::kE3M2:
      element = NumericTypeID::kFE3M2;
      break;
    case RuntimeDatatype::kE2M1:
      element = NumericTypeID::kFE2M1;
      break;

    default:
      // Fixed: `assert("...")` asserts a non-null string literal, which is
      // always true and therefore never fires. Force the failure explicitly.
      assert(false && "illegal runtime datatype!");
      break;
  }
  return element;
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
|
||||
@ -46,6 +46,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
|
||||
src/problem_space.cpp
|
||||
src/operation_profiler.cu
|
||||
src/gemm_operation_profiler.cu
|
||||
src/block_scaled_gemm_operation_profiler.cu
|
||||
src/rank_k_operation_profiler.cu
|
||||
src/rank_2k_operation_profiler.cu
|
||||
src/trmm_operation_profiler.cu
|
||||
@ -101,6 +102,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 A
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
|
||||
else()
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true)
|
||||
|
||||
@ -0,0 +1,290 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Blockscale Gemm Profiler
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
// Profiler includes
|
||||
#include "options.h"
|
||||
#include "device_context.h"
|
||||
#include "operation_profiler.h"
|
||||
#include "performance_result.h"
|
||||
#include "problem_space.h"
|
||||
#include "reduction_operation_profiler.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Abstract base class for each math function
|
||||
class BlockScaledGemmOperationProfiler : public OperationProfiler {
|
||||
public:
|
||||
|
||||
/// Problem structure obtained from problem space
|
||||
struct GemmProblem {
|
||||
|
||||
cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm};
|
||||
|
||||
int64_t m{16};
|
||||
int64_t n{16};
|
||||
int64_t k{16};
|
||||
|
||||
|
||||
int cluster_m{1};
|
||||
int cluster_n{1};
|
||||
int cluster_k{1};
|
||||
int cluster_m_fallback{1};
|
||||
int cluster_n_fallback{1};
|
||||
int cluster_k_fallback{1};
|
||||
|
||||
|
||||
int64_t lda{0};
|
||||
int64_t ldb{0};
|
||||
int64_t ldc{0};
|
||||
std::vector<uint8_t> alpha;
|
||||
std::vector<uint8_t> beta;
|
||||
|
||||
cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone};
|
||||
int split_k_slices{1};
|
||||
int batch_count{1};
|
||||
|
||||
cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
|
||||
int swizzle_size{1};
|
||||
|
||||
|
||||
cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
|
||||
cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
|
||||
|
||||
|
||||
// gemm with parallel interleaved reduction
|
||||
// gemm epilogue (alpha, beta) = (1.0, 0.0)
|
||||
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
|
||||
std::vector<uint8_t> alpha_one;
|
||||
std::vector<uint8_t> beta_zero;
|
||||
|
||||
bool use_pdl{false};
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Parses the problem
|
||||
Status parse(
|
||||
library::BlockScaledGemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Total number of bytes loaded
|
||||
int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const;
|
||||
|
||||
/// Total number of flops computed
|
||||
int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const;
|
||||
|
||||
/// Initializes a performance result
|
||||
void initialize_result(
|
||||
PerformanceResult &result,
|
||||
library::BlockScaledGemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space);
|
||||
};
|
||||
|
||||
/// Workspace used
|
||||
struct GemmWorkspace {
|
||||
|
||||
DeviceAllocation *A{nullptr};
|
||||
DeviceAllocation *SFA{nullptr};
|
||||
DeviceAllocation *B{nullptr};
|
||||
DeviceAllocation *SFB{nullptr};
|
||||
DeviceAllocation *C{nullptr};
|
||||
DeviceAllocation *Computed{nullptr};
|
||||
DeviceAllocation *Reference{nullptr};
|
||||
DeviceAllocation *Computed_SFD{nullptr};
|
||||
DeviceAllocation *Reference_SFD{nullptr};
|
||||
DeviceAllocation *Norm_constant{nullptr};
|
||||
|
||||
/// Number of copies of the problem workspace which are visited sequentially during
|
||||
/// profiling to avoid camping in the last level cache.
|
||||
int problem_count{1};
|
||||
|
||||
library::GemmUniversalConfiguration configuration;
|
||||
library::BlockScaledGemmArguments arguments;
|
||||
|
||||
/// Buffer used for the operation's host workspace
|
||||
std::vector<uint8_t> host_workspace;
|
||||
|
||||
/// Buffer used for the operations' device workspace
|
||||
DeviceAllocation device_workspace;
|
||||
|
||||
/// Library configuration and arguments for reduction operator
|
||||
library::ReductionConfiguration reduction_configuration;
|
||||
library::ReductionArguments reduction_arguments;
|
||||
|
||||
/// Buffer used for the cutlass reduction operations' host workspace
|
||||
std::vector<uint8_t> reduction_host_workspace;
|
||||
};
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// GEMM problem obtained from problem space
|
||||
GemmProblem problem_;
|
||||
|
||||
/// Device memory allocations
|
||||
GemmWorkspace gemm_workspace_;
|
||||
|
||||
/// CUTLASS parallel reduction operation to follow this* gemm operation
|
||||
library::Operation const *reduction_op_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
BlockScaledGemmOperationProfiler(Options const &options);
|
||||
|
||||
/// Destructor
|
||||
virtual ~BlockScaledGemmOperationProfiler();
|
||||
|
||||
GemmProblem const& problem() const { return problem_; }
|
||||
|
||||
/// Prints usage statement for the math function
|
||||
virtual void print_usage(std::ostream &out) const;
|
||||
|
||||
/// Prints examples
|
||||
virtual void print_examples(std::ostream &out) const;
|
||||
|
||||
/// Extracts the problem dimensions
|
||||
virtual Status initialize_configuration(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Initializes workspace
|
||||
virtual Status initialize_workspace(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
virtual bool verify_cutlass(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Measures performance results
|
||||
virtual bool profile(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
protected:
|
||||
|
||||
/// Initializes the performance result
|
||||
void initialize_result_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
library::BlockScaledGemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space);
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
bool verify_with_cublas_(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Verifies CUTLASS against host and device references
|
||||
bool verify_with_reference_(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem,
|
||||
cutlass::library::NumericTypeID element_A,
|
||||
cutlass::library::NumericTypeID element_B);
|
||||
|
||||
/// Method to profile a CUTLASS Operation
|
||||
Status profile_cutlass_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
library::Operation const *operation,
|
||||
void *arguments,
|
||||
void *host_workspace,
|
||||
void *device_workspace);
|
||||
|
||||
/// Initialize reduction problem dimensions and library::Operation
|
||||
bool initialize_reduction_configuration_(
|
||||
library::Operation const *operation,
|
||||
ProblemSpace::Problem const &problem);
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -73,6 +73,15 @@ public:
|
||||
int64_t n{16};
|
||||
int64_t k{16};
|
||||
|
||||
|
||||
int cluster_m{1};
|
||||
int cluster_n{1};
|
||||
int cluster_k{1};
|
||||
int cluster_m_fallback{1};
|
||||
int cluster_n_fallback{1};
|
||||
int cluster_k_fallback{1};
|
||||
|
||||
|
||||
int64_t lda{0};
|
||||
int64_t ldb{0};
|
||||
int64_t ldc{0};
|
||||
@ -86,6 +95,11 @@ public:
|
||||
cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
|
||||
int swizzle_size{1};
|
||||
|
||||
|
||||
cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
|
||||
cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
|
||||
|
||||
|
||||
// gemm with parallel interleaved reduction
|
||||
// gemm epilogue (alpha, beta) = (1.0, 0.0)
|
||||
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
|
||||
|
||||
@ -942,6 +942,18 @@ bool arg_as_IteratorAlgorithmID(
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
|
||||
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
||||
bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr);
|
||||
|
||||
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
||||
bool arg_as_RuntimeDatatype(
|
||||
library::RuntimeDatatype &runtime_datatype,
|
||||
char const *name,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
|
||||
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
||||
bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr);
|
||||
|
||||
|
||||
1371
tools/profiler/src/block_scaled_gemm_operation_profiler.cu
Normal file
1371
tools/profiler/src/block_scaled_gemm_operation_profiler.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -38,6 +38,7 @@
|
||||
// Profiler includes
|
||||
#include "cutlass/profiler/cutlass_profiler.h"
|
||||
#include "cutlass/profiler/gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/rank_k_operation_profiler.h"
|
||||
#include "cutlass/profiler/rank_2k_operation_profiler.h"
|
||||
#include "cutlass/profiler/trmm_operation_profiler.h"
|
||||
@ -60,6 +61,8 @@ CutlassProfiler::CutlassProfiler(
|
||||
|
||||
operation_profilers_.emplace_back(new GemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new Conv2dOperationProfiler(options));
|
||||
|
||||
@ -616,6 +616,48 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
cutlass::reference::device::BlockFillRandom<cutlass::float_ue4m3_t>(
|
||||
reinterpret_cast<cutlass::float_ue4m3_t *>(pointer_),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
cutlass::reference::device::BlockFillRandom<cutlass::float_ue8m0_t>(
|
||||
reinterpret_cast<cutlass::float_ue8m0_t *>(pointer_),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::device::BlockFillRandom<cutlass::float_e2m3_t>(
|
||||
reinterpret_cast<cutlass::float_e2m3_t *>(pointer_),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
cutlass::reference::device::BlockFillRandom<cutlass::float_e3m2_t>(
|
||||
reinterpret_cast<cutlass::float_e3m2_t *>(pointer_),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
cutlass::reference::device::BlockFillRandom<cutlass::float_e2m1_t>(
|
||||
reinterpret_cast<cutlass::float_e2m1_t *>(pointer_),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF64:
|
||||
cutlass::reference::device::BlockFillRandom<double>(
|
||||
reinterpret_cast<double *>(pointer_),
|
||||
@ -771,6 +813,50 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::float_ue4m3_t>(
|
||||
reinterpret_cast<cutlass::float_ue4m3_t *>(host_data.data()),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::float_e2m3_t>(
|
||||
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::float_e3m2_t>(
|
||||
reinterpret_cast<cutlass::float_e3m2_t *>(host_data.data()),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::float_e2m1_t>(
|
||||
reinterpret_cast<cutlass::float_e2m1_t *>(host_data.data()),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::float_ue8m0_t>(
|
||||
reinterpret_cast<cutlass::float_ue8m0_t *>(host_data.data()),
|
||||
capacity_,
|
||||
seed,
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
|
||||
reinterpret_cast<cutlass::half_t *>(host_data.data()),
|
||||
@ -990,6 +1076,50 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) {
|
||||
static_cast<cutlass::float_e5m2_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::float_ue4m3_t>(
|
||||
reinterpret_cast<cutlass::float_ue4m3_t *>(pointer_),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_ue4m3_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_ue4m3_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::float_e2m3_t>(
|
||||
reinterpret_cast<cutlass::float_e2m3_t *>(pointer_),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_e2m3_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_e2m3_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::float_e3m2_t>(
|
||||
reinterpret_cast<cutlass::float_e3m2_t *>(pointer_),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_e3m2_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_e3m2_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::float_e2m1_t>(
|
||||
reinterpret_cast<cutlass::float_e2m1_t *>(pointer_),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_e2m1_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_e2m1_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::float_ue8m0_t>(
|
||||
reinterpret_cast<cutlass::float_ue8m0_t *>(pointer_),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_ue8m0_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::half_t>(
|
||||
reinterpret_cast<cutlass::half_t *>(pointer_),
|
||||
@ -1220,6 +1350,50 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) {
|
||||
static_cast<cutlass::float_e5m2_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::float_ue4m3_t>(
|
||||
reinterpret_cast<cutlass::float_ue4m3_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_ue4m3_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_ue4m3_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::float_e2m3_t>(
|
||||
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_e2m3_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_e2m3_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::float_e3m2_t>(
|
||||
reinterpret_cast<cutlass::float_e3m2_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_e3m2_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_e3m2_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::float_e2m1_t>(
|
||||
reinterpret_cast<cutlass::float_e2m1_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_e2m1_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_e2m1_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::float_ue8m0_t>(
|
||||
reinterpret_cast<cutlass::float_ue8m0_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<cutlass::float_ue8m0_t>(dist.sequential.delta),
|
||||
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::half_t>(
|
||||
reinterpret_cast<cutlass::half_t *>(host_data.data()),
|
||||
@ -1516,6 +1690,34 @@ bool DeviceAllocation::block_compare_equal(
|
||||
reinterpret_cast<float_e5m2_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e5m2_t const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
return reference::device::BlockCompareEqual<float_ue4m3_t>(
|
||||
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
|
||||
reinterpret_cast<float_ue4m3_t const *>(ptr_B),
|
||||
capacity);
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
return reference::device::BlockCompareEqual<float_ue8m0_t>(
|
||||
reinterpret_cast<float_ue8m0_t const *>(ptr_A),
|
||||
reinterpret_cast<float_ue8m0_t const *>(ptr_B),
|
||||
capacity);
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
return reference::device::BlockCompareEqual<float_e2m3_t>(
|
||||
reinterpret_cast<float_e2m3_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e2m3_t const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
return reference::device::BlockCompareEqual<float_e3m2_t>(
|
||||
reinterpret_cast<float_e3m2_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e3m2_t const *>(ptr_B),
|
||||
capacity);
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
return reference::device::BlockCompareEqual<float_e2m1_t>(
|
||||
reinterpret_cast<float_e2m1_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e2m1_t const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
return reference::device::BlockCompareEqual<half_t>(
|
||||
reinterpret_cast<half_t const *>(ptr_A),
|
||||
@ -1684,6 +1886,46 @@ bool DeviceAllocation::block_compare_relatively_equal(
|
||||
capacity,
|
||||
static_cast<float_e5m2_t>(epsilon),
|
||||
static_cast<float_e5m2_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
return reference::device::BlockCompareRelativelyEqual<float_ue4m3_t>(
|
||||
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
|
||||
reinterpret_cast<float_ue4m3_t const *>(ptr_B),
|
||||
capacity,
|
||||
static_cast<float_ue4m3_t>(epsilon),
|
||||
static_cast<float_ue4m3_t>(nonzero_floor));
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
return reference::device::BlockCompareRelativelyEqual<float_ue8m0_t>(
|
||||
reinterpret_cast<float_ue8m0_t const *>(ptr_A),
|
||||
reinterpret_cast<float_ue8m0_t const *>(ptr_B),
|
||||
capacity,
|
||||
static_cast<float_ue8m0_t>(epsilon),
|
||||
static_cast<float_ue8m0_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
return reference::device::BlockCompareRelativelyEqual<float_e2m3_t>(
|
||||
reinterpret_cast<float_e2m3_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e2m3_t const *>(ptr_B),
|
||||
capacity,
|
||||
static_cast<float_e2m3_t>(epsilon),
|
||||
static_cast<float_e2m3_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
return reference::device::BlockCompareRelativelyEqual<float_e3m2_t>(
|
||||
reinterpret_cast<float_e3m2_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e3m2_t const *>(ptr_B),
|
||||
capacity,
|
||||
static_cast<float_e3m2_t>(epsilon),
|
||||
static_cast<float_e3m2_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
return reference::device::BlockCompareRelativelyEqual<float_e2m1_t>(
|
||||
reinterpret_cast<float_e2m1_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e2m1_t const *>(ptr_B),
|
||||
capacity,
|
||||
static_cast<float_e2m1_t>(epsilon),
|
||||
static_cast<float_e2m1_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
return reference::device::BlockCompareRelativelyEqual<half_t>(
|
||||
reinterpret_cast<half_t const *>(ptr_A),
|
||||
@ -2026,6 +2268,27 @@ void DeviceAllocation::write_tensor_csv(
|
||||
case library::NumericTypeID::kFE5M2:
|
||||
write_tensor_csv_static_type<float_e5m2_t>(out, *this);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
write_tensor_csv_static_type<float_ue4m3_t>(out, *this);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
write_tensor_csv_static_type<float_e2m3_t>(out, *this);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
write_tensor_csv_static_type<float_e3m2_t>(out, *this);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
write_tensor_csv_static_type<float_e2m1_t>(out, *this);
|
||||
break;
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
write_tensor_csv_static_type<float_ue8m0_t>(out, *this);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
write_tensor_csv_static_type<half_t>(out, *this);
|
||||
break;
|
||||
@ -2193,6 +2456,27 @@ void DeviceAllocation::fill_device(double val = 0.0) {
|
||||
case library::NumericTypeID::kFE5M2:
|
||||
tensor_fill<float_e5m2_t>(*this, static_cast<float_e5m2_t>(val));
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
tensor_fill<float_ue4m3_t>(*this, static_cast<float_ue4m3_t>(val));
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
tensor_fill<float_ue8m0_t>(*this, static_cast<float_ue8m0_t>(val));
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
tensor_fill<float_e2m3_t>(*this, static_cast<float_e2m3_t>(val));
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
tensor_fill<float_e3m2_t>(*this, static_cast<float_e3m2_t>(val));
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
tensor_fill<float_e2m1_t>(*this, static_cast<float_e2m1_t>(val));
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
tensor_fill<half_t>(*this, static_cast<half_t>(val));
|
||||
break;
|
||||
@ -2288,6 +2572,47 @@ void DeviceAllocation::fill_host(double val = 0.0) {
|
||||
std::vector<uint8_t> host_data(bytes());
|
||||
|
||||
switch (this->type()) {
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
cutlass::reference::host::BlockFill<float_ue4m3_t>(
|
||||
reinterpret_cast<float_ue4m3_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float_ue4m3_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
cutlass::reference::host::BlockFill<float_ue8m0_t>(
|
||||
reinterpret_cast<float_ue8m0_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float_ue8m0_t>(val)
|
||||
);
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::host::BlockFill<float_e2m3_t>(
|
||||
reinterpret_cast<float_e2m3_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float_e2m3_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
cutlass::reference::host::BlockFill<float_e3m2_t>(
|
||||
reinterpret_cast<float_e3m2_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float_e3m2_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
cutlass::reference::host::BlockFill<float_e2m1_t>(
|
||||
reinterpret_cast<float_e2m1_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float_e2m1_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE4M3:
|
||||
cutlass::reference::host::BlockFill<float_e4m3_t>(
|
||||
reinterpret_cast<float_e4m3_t *>(host_data.data()),
|
||||
|
||||
@ -104,6 +104,25 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
|
||||
case library::NumericTypeID::kFE5M2:
|
||||
data_distribution.set_uniform(-1, 1, 0);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
data_distribution.set_uniform(-2, 2, 0);
|
||||
break;
|
||||
case library::NumericTypeID::kFE3M2:
|
||||
data_distribution.set_uniform(-2, 2, 0);
|
||||
break;
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
data_distribution.set_uniform(-2, 2, 0);
|
||||
break;
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
data_distribution.set_uniform(1, 4, 0);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
data_distribution.set_uniform(1, 4, 0);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
data_distribution.set_uniform(-3, 3, 0);
|
||||
break;
|
||||
|
||||
@ -76,6 +76,8 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
|
||||
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
|
||||
{ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"},
|
||||
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
||||
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
||||
{ArgumentTypeID::kInteger, {"use_pdl", "use-pdl"}, "Use PDL (true, false)"},
|
||||
{ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"},
|
||||
},
|
||||
@ -172,6 +174,38 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
this->k = 1024;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_m = 1;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_n = 1;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_k, "cluster_k", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_k = 1;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_m_fallback = 0;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_n_fallback = 0;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) {
|
||||
// default value
|
||||
this->cluster_k_fallback = 0;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) {
|
||||
// default value
|
||||
this->use_pdl = false;
|
||||
@ -192,6 +226,18 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
this->split_k_slices = 1;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) {
|
||||
// default value
|
||||
this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic;
|
||||
}
|
||||
|
||||
if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) {
|
||||
// default value
|
||||
this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic;
|
||||
}
|
||||
|
||||
|
||||
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
|
||||
// default value
|
||||
this->batch_count = 1;
|
||||
@ -338,6 +384,15 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "n", problem_space, n);
|
||||
set_argument(result, "k", problem_space, k);
|
||||
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, cluster_m);
|
||||
set_argument(result, "cluster_n", problem_space, cluster_n);
|
||||
set_argument(result, "cluster_k", problem_space, cluster_k);
|
||||
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
|
||||
|
||||
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||
set_argument(result, "split_k_slices", problem_space, split_k_slices);
|
||||
set_argument(result, "batch_count", problem_space, batch_count);
|
||||
@ -345,6 +400,11 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "swizzle_size", problem_space, swizzle_size);
|
||||
set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl));
|
||||
|
||||
|
||||
set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a));
|
||||
set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b));
|
||||
|
||||
|
||||
set_argument(result, "alpha", problem_space,
|
||||
library::lexical_cast(alpha, operation_desc.element_epilogue));
|
||||
|
||||
@ -388,6 +448,14 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
gemm_workspace_[i].configuration.problem_size.m() = int(problem_.m);
|
||||
gemm_workspace_[i].configuration.problem_size.n() = int(problem_.n);
|
||||
gemm_workspace_[i].configuration.problem_size.k() = int(problem_.k);
|
||||
|
||||
gemm_workspace_[i].configuration.cluster_shape.m() = int(problem_.cluster_m);
|
||||
gemm_workspace_[i].configuration.cluster_shape.n() = int(problem_.cluster_n);
|
||||
gemm_workspace_[i].configuration.cluster_shape.k() = int(problem_.cluster_k);
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.m() = int(problem_.cluster_m_fallback);
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.n() = int(problem_.cluster_n_fallback);
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.k() = int(problem_.cluster_k_fallback);
|
||||
|
||||
gemm_workspace_[i].configuration.lda = problem_.lda;
|
||||
gemm_workspace_[i].configuration.ldb = problem_.ldb;
|
||||
gemm_workspace_[i].configuration.ldc = problem_.ldc;
|
||||
@ -423,6 +491,15 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
gemm_workspace_[i].arguments.swizzle_size = problem_.swizzle_size;
|
||||
gemm_workspace_[i].arguments.raster_order = problem_.raster_order;
|
||||
gemm_workspace_[i].arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)};
|
||||
gemm_workspace_[i].arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)};
|
||||
gemm_workspace_[i].arguments.split_k_slices = problem_.split_k_slices;
|
||||
|
||||
|
||||
gemm_workspace_[i].arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a;
|
||||
gemm_workspace_[i].arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b;
|
||||
|
||||
|
||||
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
||||
if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) {
|
||||
return can_implement;
|
||||
@ -621,6 +698,9 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
if (options.execution_mode != ExecutionMode::kDryRun) {
|
||||
// NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels
|
||||
gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)};
|
||||
gemm_workspace_[i].arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)};
|
||||
gemm_workspace_[i].arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)};
|
||||
gemm_workspace_[i].arguments.split_k_slices = problem_.split_k_slices;
|
||||
gemm_workspace_[i].arguments.batch_count = problem_.batch_count;
|
||||
gemm_workspace_[i].arguments.lda = problem_.lda;
|
||||
gemm_workspace_[i].arguments.ldb = problem_.ldb;
|
||||
@ -857,12 +937,32 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
}
|
||||
#endif // #if CUTLASS_ENABLE_CUBLAS
|
||||
|
||||
|
||||
cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.front().arguments.runtime_input_datatype_a;
|
||||
cutlass::library::RuntimeDatatype runtime_datatype_b = gemm_workspace_.front().arguments.runtime_input_datatype_b;
|
||||
|
||||
bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic;
|
||||
bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic;
|
||||
|
||||
assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatype should be both dynamic or static.");
|
||||
|
||||
|
||||
library::GemmDescription const &gemm_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
|
||||
cutlass::library::NumericTypeID element_A = gemm_desc.A.element;
|
||||
cutlass::library::NumericTypeID element_B = gemm_desc.B.element;
|
||||
|
||||
if (is_runtime_datatype_a) {
|
||||
element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a);
|
||||
}
|
||||
|
||||
if (is_runtime_datatype_b) {
|
||||
element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b);
|
||||
}
|
||||
|
||||
|
||||
bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem, element_A, element_B);
|
||||
|
||||
// Update disposition to worst case verification outcome among all
|
||||
@ -1087,6 +1187,14 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
gemm_workspace_[i].configuration.problem_size.m(),
|
||||
gemm_workspace_[i].configuration.problem_size.n(),
|
||||
gemm_workspace_[i].configuration.problem_size.k(),
|
||||
|
||||
gemm_workspace_[i].configuration.cluster_shape.m(),
|
||||
gemm_workspace_[i].configuration.cluster_shape.n(),
|
||||
gemm_workspace_[i].configuration.cluster_shape.k(),
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.m(),
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.n(),
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.k(),
|
||||
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
|
||||
|
||||
@ -91,6 +91,11 @@ OperationProfiler::OperationProfiler(
|
||||
{ArgumentTypeID::kInteger, {"cluster_m", "cluster-shape::m"}, "Cluster shape in the M dimension"},
|
||||
{ArgumentTypeID::kInteger, {"cluster_n", "cluster-shape::n"}, "Cluster shape in the N dimension"},
|
||||
{ArgumentTypeID::kInteger, {"cluster_k", "cluster-shape::k"}, "Cluster shape in the K dimension"},
|
||||
|
||||
{ArgumentTypeID::kInteger, {"cluster_m_fallback", "cluster-shape-fallback::m"}, "Fallback Cluster shape in the M dimension"},
|
||||
{ArgumentTypeID::kInteger, {"cluster_n_fallback", "cluster-shape-fallback::n"}, "Fallback Cluster shape in the N dimension"},
|
||||
{ArgumentTypeID::kInteger, {"cluster_k_fallback", "cluster-shape-fallback::k"}, "Fallback Cluster shape in the K dimension"},
|
||||
|
||||
{ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply"},
|
||||
{ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension"},
|
||||
{ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension"},
|
||||
@ -174,6 +179,11 @@ bool OperationProfiler::satisfies(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool dynamic_cluster = int64_t(op_desc.tile_description.cluster_shape.m()) == 0 ||
|
||||
int64_t(op_desc.tile_description.cluster_shape.n()) == 0 ||
|
||||
int64_t(op_desc.tile_description.cluster_shape.k()) == 0;
|
||||
|
||||
int64_t int_value;
|
||||
|
||||
if (arg_as_int(int_value, "inst_m", problem_space, problem)) {
|
||||
@ -212,6 +222,7 @@ bool OperationProfiler::satisfies(
|
||||
}
|
||||
}
|
||||
|
||||
if (!dynamic_cluster) {
|
||||
if (arg_as_int(int_value, "cluster_m", problem_space, problem)) {
|
||||
if (int64_t(op_desc.tile_description.cluster_shape.m()) != int_value) {
|
||||
return false;
|
||||
@ -230,6 +241,7 @@ bool OperationProfiler::satisfies(
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if (arg_as_int(int_value, "stages", problem_space, problem)) {
|
||||
if (int64_t(op_desc.tile_description.threadblock_stages) != int_value) {
|
||||
return false;
|
||||
@ -296,6 +308,11 @@ std::ostream& operator<<(std::ostream& out, library::OperationKind provider) {
|
||||
if (provider == library::OperationKind::kGemm) {
|
||||
out << "kGemm";
|
||||
}
|
||||
|
||||
else if (provider == library::OperationKind::kBlockScaledGemm) {
|
||||
out << "kBlockScaledGemm";
|
||||
}
|
||||
|
||||
else if (provider == library::OperationKind::kRankK) {
|
||||
out << "kRankK";
|
||||
}
|
||||
|
||||
@ -33,6 +33,7 @@
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <set>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -810,16 +811,27 @@ Options::Options(cutlass::CommandLine const &cmdline):
|
||||
}
|
||||
else if (cmdline.check_cmd_line_flag("kernels")) {
|
||||
cmdline.get_cmd_line_arguments("kernels", operation_names);
|
||||
profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match");
|
||||
profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled");
|
||||
}
|
||||
|
||||
if (cmdline.check_cmd_line_flag("kernels-file")) {
|
||||
std::string filename;
|
||||
cmdline.get_cmd_line_argument("kernels-file", filename, {});
|
||||
std::ifstream input(filename);
|
||||
if (!input.good()) {
|
||||
throw std::runtime_error("failed to open: " + filename);
|
||||
}
|
||||
for (std::string line; getline(input, line);) {
|
||||
operation_names.push_back(line);
|
||||
}
|
||||
}
|
||||
|
||||
if (cmdline.check_cmd_line_flag("ignore-kernels")) {
|
||||
cmdline.get_cmd_line_arguments("ignore-kernels", excluded_operation_names);
|
||||
profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match");
|
||||
profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled");
|
||||
}
|
||||
|
||||
profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match");
|
||||
profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled");
|
||||
|
||||
// Prevent launches on the device for anything other than CUTLASS operation
|
||||
// Allow verification only on host
|
||||
if (execution_mode == ExecutionMode::kTrace) {
|
||||
@ -856,6 +868,11 @@ void Options::print_usage(std::ostream &out) const {
|
||||
<< " (\"s1688\" and \"nt\") or (\"s844\" and \"tn\" and \"align8\") in their" << end_of_line
|
||||
<< " operation name using --kernels=\"s1688*nt, s884*tn*align8\"\n\n"
|
||||
|
||||
<< " --kernels-file=<filename> "
|
||||
<< " Same behavior as --kernels, but kernel names are specified in a file" << end_of_line
|
||||
<< " with one kernel on each line. Set of profiled kernels is the union of kernels specified" << end_of_line
|
||||
<< " here and those specified in `kernels`.\n\n"
|
||||
|
||||
<< " --ignore-kernels=<string_list> "
|
||||
<< " Excludes kernels whose names match anything in this list.\n\n"
|
||||
;
|
||||
|
||||
@ -879,6 +879,32 @@ bool arg_as_NumericTypeID(
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
||||
bool arg_as_RuntimeDatatype(
|
||||
library::RuntimeDatatype &runtime_datatype,
|
||||
KernelArgument::Value const *value_ptr) {
|
||||
|
||||
if (value_ptr->not_null) {
|
||||
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) {
|
||||
|
||||
runtime_datatype = library::from_string<library::RuntimeDatatype>(
|
||||
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element);
|
||||
if (runtime_datatype == library::RuntimeDatatype::kInvalid) {
|
||||
throw std::runtime_error(
|
||||
"arg_as_RuntimeDatatype() - illegal cast.");
|
||||
}
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error(
|
||||
"arg_as_RuntimeDatatype() - illegal cast.");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
||||
bool arg_as_RasterOrder(
|
||||
library::RasterOrder &raster_order,
|
||||
@ -945,6 +971,21 @@ bool arg_as_LayoutTypeID(
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
||||
bool arg_as_RuntimeDatatype(
|
||||
library::RuntimeDatatype &runtime_datatype,
|
||||
char const *name,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
size_t idx = problem_space.argument_index(name);
|
||||
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
||||
|
||||
return arg_as_RuntimeDatatype(runtime_datatype, value_ptr);
|
||||
}
|
||||
|
||||
|
||||
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
||||
bool arg_as_LayoutTypeID(
|
||||
library::LayoutTypeID &layout_type,
|
||||
|
||||
@ -922,7 +922,7 @@ __global__ void Conv3dWgrad(
|
||||
filter_s = problem_size.S - 1 - filter_s;
|
||||
}
|
||||
|
||||
int d = Z * problem_size.stride_d - problem_size.pad_w + filter_t * problem_size.dilation_d;
|
||||
int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
|
||||
int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
||||
int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
||||
|
||||
|
||||
@ -352,7 +352,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
ScalarType, ComputeType, xor_add<ComputeType>>(
|
||||
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
||||
}
|
||||
|
||||
@ -367,7 +367,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
ScalarType, ComputeType, xor_add<ComputeType>>(
|
||||
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
@ -389,7 +389,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
ScalarType, ComputeType, and_add<ComputeType>>(
|
||||
ScalarType, ComputeType, and_popc_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
||||
}
|
||||
|
||||
@ -404,7 +404,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
ScalarType, ComputeType, and_add<ComputeType>>(
|
||||
ScalarType, ComputeType, and_popc_add<ComputeType>>(
|
||||
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
||||
}
|
||||
};
|
||||
|
||||
@ -42,6 +42,7 @@
|
||||
#include "cutlass/relatively_equal.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/pointer.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -59,10 +60,20 @@ struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
//
|
||||
// Gett Mainloop Parameters
|
||||
//
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAccumulator_,
|
||||
class TensorA_, // (M, K, L)
|
||||
class TensorB_ // (N, K, L)
|
||||
|
||||
, class TensorSfA_ = TensorA_,
|
||||
class TensorSfB_ = TensorB_
|
||||
|
||||
>
|
||||
struct GettMainloopParams {
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
@ -79,23 +90,105 @@ struct GettMainloopParams {
|
||||
ComplexTransform transform_A = ComplexTransform::kNone;
|
||||
ComplexTransform transform_B = ComplexTransform::kNone;
|
||||
|
||||
|
||||
using TensorSfA = TensorSfA_;
|
||||
using TensorSfB = TensorSfB_;
|
||||
using EngineSfA = typename TensorSfA::engine_type;
|
||||
using LayoutSfA = typename TensorSfA::layout_type;
|
||||
using EngineSfB = typename TensorSfB::engine_type;
|
||||
using LayoutSfB = typename TensorSfB::layout_type;
|
||||
TensorSfA_ SfA{};
|
||||
TensorSfB_ SfB{};
|
||||
|
||||
|
||||
GettMainloopParams() {}
|
||||
|
||||
GettMainloopParams(TensorA tensor_A, TensorB tensor_B)
|
||||
: A(tensor_A), B(tensor_B) {}
|
||||
|
||||
|
||||
GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
|
||||
: A(tensor_A), SfA(tensor_SfA),
|
||||
B(tensor_B), SfB(tensor_SfB) {}
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAccumulator_,
|
||||
class TensorA_, // (M, K, L)
|
||||
class TensorSfA_, // (M, K, L)
|
||||
class TensorB_, // (N, K, L)
|
||||
class TensorSfB_ // (N, K, L)
|
||||
>
|
||||
struct GettBlockScalingMainloopParams : public GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_> {
|
||||
using Base = GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_>;
|
||||
using ElementAccumulator = typename Base::ElementAccumulator;
|
||||
using TensorA = typename Base::TensorA;
|
||||
using TensorB = typename Base::TensorB;
|
||||
using EngineA = typename Base::EngineA;
|
||||
using LayoutA = typename Base::LayoutA;
|
||||
using EngineB = typename Base::EngineB;
|
||||
using LayoutB = typename Base::LayoutB;
|
||||
ComplexTransform transform_A = Base::transform_A;
|
||||
ComplexTransform transform_B = Base::transform_B;
|
||||
|
||||
using TensorSfA = typename Base::TensorSfA;
|
||||
using TensorSfB = typename Base::TensorSfB;
|
||||
using EngineSfA = typename Base::EngineSfA;
|
||||
using LayoutSfA = typename Base::LayoutSfA;
|
||||
using EngineSfB = typename Base::EngineSfB;
|
||||
using LayoutSfB = typename Base::LayoutSfB;
|
||||
|
||||
GettBlockScalingMainloopParams() {}
|
||||
|
||||
GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
|
||||
: Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {}
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class SfStrategy {
|
||||
None = 0,
|
||||
SfDGen = 1
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
//
|
||||
// Gett Epilogue Parameters
|
||||
//
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementScalingFactor_,
|
||||
class ElementAccumulator_,
|
||||
class ElementCompute_,
|
||||
class TensorC_, // (M, N, L)
|
||||
class TensorD_, // (M, N, L)
|
||||
class VectorBias_ = TensorD_, // (M, 1)
|
||||
class TensorAux_ = TensorD_, // (M, N, L)
|
||||
class VectorAlpha_ = TensorD_, // (M, 1)
|
||||
class VectorBeta_ = VectorAlpha_, // (M, 1)
|
||||
class TensorC_, // (M, N, L)
|
||||
class TensorD_, // (M, N, L)
|
||||
class VectorBias_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
|
||||
class TensorAux_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, N, L)
|
||||
class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
|
||||
class VectorBeta_ = VectorAlpha_, // (M, 1)
|
||||
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
|
||||
class TensorSFD_ = TensorD_,
|
||||
class SFD_VectorSize_ = cute::Int<0>,
|
||||
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
|
||||
bool PerColumnBias_ = false
|
||||
,
|
||||
SfStrategy SfGenStrategy_ = SfStrategy::None
|
||||
>
|
||||
struct GettEpilogueParams {
|
||||
using ElementScalar = ElementScalar_;
|
||||
@ -108,6 +201,8 @@ struct GettEpilogueParams {
|
||||
using VectorBias = VectorBias_;
|
||||
using VectorAlpha = VectorAlpha_;
|
||||
using VectorBeta = VectorBeta_;
|
||||
using TensorSFD = TensorSFD_;
|
||||
using SFD_VectorSize = SFD_VectorSize_;
|
||||
using ActivationFunctor = ActivationFunctor_;
|
||||
using BiasBinaryOp = BiasBinaryOp_;
|
||||
|
||||
@ -115,7 +210,11 @@ struct GettEpilogueParams {
|
||||
using LayoutC = typename TensorC::layout_type;
|
||||
using EngineD = typename TensorD::engine_type;
|
||||
using LayoutD = typename TensorD::layout_type;
|
||||
using EngineSfD = typename TensorSFD::engine_type;
|
||||
using LayoutSfD = typename TensorSFD::layout_type;
|
||||
static constexpr bool PerColumnBias = PerColumnBias_;
|
||||
static constexpr SfStrategy SfGenStrategy = SfGenStrategy_;
|
||||
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
|
||||
@ -125,7 +224,8 @@ struct GettEpilogueParams {
|
||||
TensorAux Aux{};
|
||||
VectorAlpha Valpha{};
|
||||
VectorBeta Vbeta{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
TensorSFD SfD{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
|
||||
ElementAccumulator* abs_max_D = nullptr;
|
||||
ElementAccumulator* abs_max_Aux = nullptr;
|
||||
@ -137,8 +237,250 @@ struct GettEpilogueParams {
|
||||
ElementScalingFactor scale_aux = ElementScalingFactor(1);
|
||||
|
||||
bool beta_per_channel_scaling = false;
|
||||
GettEpilogueParams() {}
|
||||
|
||||
GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
|
||||
: alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {}
|
||||
|
||||
|
||||
GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
|
||||
: alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {}
|
||||
|
||||
|
||||
GettEpilogueParams(
|
||||
ElementScalar alpha, ElementScalar beta,
|
||||
TensorC tensor_C, TensorD tensor_D,
|
||||
VectorBias bias, TensorAux tensor_aux,
|
||||
VectorAlpha vector_alpha, VectorBeta vector_beta)
|
||||
: alpha(alpha), beta(beta),
|
||||
C(tensor_C), D(tensor_D),
|
||||
Bias(bias), Aux(tensor_aux),
|
||||
Valpha(vector_alpha), Vbeta(vector_beta) {}
|
||||
};
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementAccumulator_,
|
||||
class ElementCompute_,
|
||||
class TensorC_,
|
||||
class TensorD_,
|
||||
class TensorSfD_ = TensorD_,
|
||||
class SFD_VectorSize_ = cute::Int<0>,
|
||||
SfStrategy SfGenStrategy_ = SfStrategy::None
|
||||
>
|
||||
struct GettBlockScalingEpilogueParams : public GettEpilogueParams<
|
||||
ElementScalar_, // ElementScalar
|
||||
ElementScalar_, // ElementScalingFactor
|
||||
ElementAccumulator_, // ElementAccumulator
|
||||
ElementCompute_, // ElementCompute
|
||||
TensorC_, // TensorC (M, N, L)
|
||||
TensorD_, // TensorD (M, N, L)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
|
||||
cutlass::epilogue::thread::Identity<ElementCompute_>, //
|
||||
TensorSfD_, // TensorSfD
|
||||
SFD_VectorSize_, // SFD_VectorSize
|
||||
cutlass::plus<ElementCompute_>, // class BiasBinaryOp_ =
|
||||
false, //PerColumnBias_
|
||||
SfGenStrategy_ // SfGenStrategy
|
||||
> {
|
||||
using Base = GettEpilogueParams<
|
||||
ElementScalar_, // ElementScalar
|
||||
ElementScalar_, // ElementScalingFactor
|
||||
ElementAccumulator_, // ElementAccumulator
|
||||
ElementCompute_, // ElementCompute
|
||||
TensorC_, // TensorC (M, N, L)
|
||||
TensorD_, // TensorD (M, N, L)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
|
||||
decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
|
||||
cutlass::epilogue::thread::Identity<ElementCompute_>, //
|
||||
TensorSfD_, // TensorSfD
|
||||
SFD_VectorSize_, // SFD_VectorSize
|
||||
cutlass::plus<ElementCompute_>, // BiasBinaryOp
|
||||
false, // PerColumnBias
|
||||
SfGenStrategy_ // SfGenStrategy
|
||||
>;
|
||||
using ElementScalar = typename Base::ElementScalar;
|
||||
using ElementScalingFactor = typename Base::ElementScalingFactor;
|
||||
using ElementAccumulator = typename Base::ElementAccumulator;
|
||||
using ElementCompute = typename Base::ElementCompute;
|
||||
using TensorC = typename Base::TensorC;
|
||||
using TensorD = typename Base::TensorD;
|
||||
using TensorAux = typename Base::TensorAux;
|
||||
using VectorBias = typename Base::VectorBias;
|
||||
using VectorAlpha = typename Base::VectorAlpha;
|
||||
using VectorBeta = typename Base::VectorBeta;
|
||||
using TensorSFD = typename Base::TensorSFD;
|
||||
using SFD_VectorSize = typename Base::SFD_VectorSize;
|
||||
using ActivationFunctor = typename Base::ActivationFunctor;
|
||||
using BiasBinaryOp = typename Base::BiasBinaryOp;
|
||||
|
||||
using EngineC = typename Base::EngineC;
|
||||
using LayoutC = typename Base::LayoutC;
|
||||
using EngineD = typename Base::EngineD;
|
||||
using LayoutD = typename Base::LayoutD;
|
||||
using EngineSfD = typename Base::EngineSfD;
|
||||
using LayoutSfD = typename Base::LayoutSfD;
|
||||
static constexpr bool PerColumnBias = Base::PerColumnBias;
|
||||
static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy;
|
||||
|
||||
GettBlockScalingEpilogueParams() {}
|
||||
|
||||
GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
|
||||
: Base(alpha, beta, tensor_C, tensor_D) {}
|
||||
|
||||
GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD)
|
||||
: Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {}
|
||||
|
||||
GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
|
||||
: Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////
|
||||
//
|
||||
// Generic Gett 3x Implementation
|
||||
//
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <int kVectorSize, class EpilogueParams, class TensorD, class TensorSFD, class ElementCompute, int kBlockM, int kBlockN>
|
||||
void compute_1d_scaling_factor_and_quantized_output(
|
||||
EpilogueParams const& epilogue_params,
|
||||
TensorD &tensor_D,
|
||||
TensorSFD &tensor_SfD,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementCompute (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
using ElementD = typename ElementTraits<typename EpilogueParams::EngineD::value_type>::type;
|
||||
using ElementSfD = typename ElementTraits<typename EpilogueParams::EngineSfD::value_type>::type;
|
||||
|
||||
int const M = cute::size<0>(tensor_D.layout());
|
||||
int const N = cute::size<1>(tensor_D.layout());
|
||||
int const L = cute::size<2>(tensor_D.layout());
|
||||
|
||||
auto mul = cutlass::multiplies<ElementCompute>{};
|
||||
auto div = divides<ElementCompute>{};
|
||||
// Get FP max
|
||||
ElementCompute fp_max = ElementCompute(std::numeric_limits<ElementD>::max());
|
||||
float scale_down_factor = div(1.0f, fp_max);
|
||||
// Get st' = st / FP max
|
||||
ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor);
|
||||
|
||||
absolute_value_op<ElementCompute> abs_op;
|
||||
maximum_with_nan_propogation<ElementCompute> max_op;
|
||||
|
||||
if constexpr (cute::is_constant<1, decltype(cute::stride<0,1>(tensor_SfD))>::value) {
|
||||
// MN major output
|
||||
int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize);
|
||||
// Col major output
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
|
||||
int64_t col = n + n_b;
|
||||
|
||||
/// Step1: get max across a vector
|
||||
ElementCompute accum_max = ElementCompute(0);
|
||||
for (int v = 0; v < kVectorSize; v++) {
|
||||
int accum_row = v_b * kVectorSize + v;
|
||||
int64_t output_row = accum_row + m;
|
||||
if (output_row < M && col < N) {
|
||||
accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b]));
|
||||
}
|
||||
}
|
||||
|
||||
/// Step2: Compute Scale
|
||||
ElementCompute pvscale = mul(accum_max, st_scaled_down);
|
||||
ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
|
||||
// Store the Scaling Factors
|
||||
int64_t sf_row = m + kVectorSize * v_b;
|
||||
if (sf_row < M && col < N) {
|
||||
tensor_SfD(sf_row, col, l) = qpvscale;
|
||||
}
|
||||
|
||||
/// Step3: Compute quantized output values
|
||||
ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
|
||||
// Get float reciprocal
|
||||
ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
|
||||
ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
|
||||
// Map INF to fp32::max
|
||||
acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
|
||||
// Store the intermediate_accum
|
||||
for (int v = 0; v < kVectorSize; v++) {
|
||||
int accum_row = v_b * kVectorSize + v;
|
||||
int64_t output_row = accum_row + m;
|
||||
if (output_row < M && col < N) {
|
||||
acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize);
|
||||
// row major output
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
|
||||
int64_t row = m + m_b;
|
||||
|
||||
/// Step1: get max across a vector
|
||||
ElementCompute accum_max = ElementCompute(0);
|
||||
for (int v = 0; v < kVectorSize; v++) {
|
||||
int accum_col = v_b * kVectorSize + v;
|
||||
int64_t output_col = accum_col + n;
|
||||
if (row < M && output_col < N) {
|
||||
accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col]));
|
||||
}
|
||||
}
|
||||
|
||||
/// Step2: Compute Scale
|
||||
ElementCompute pvscale = mul(accum_max, st_scaled_down);
|
||||
ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
|
||||
// Store the Scaling Factors
|
||||
int64_t sf_col = n + kVectorSize * v_b;
|
||||
|
||||
if (row < M && sf_col < N) {
|
||||
tensor_SfD(row, sf_col, l) = qpvscale;
|
||||
}
|
||||
|
||||
/// Step3: Compute quantized output values
|
||||
ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
|
||||
// Get float reciprocal
|
||||
ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
|
||||
ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
|
||||
// Map INF to fp32::max
|
||||
acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
|
||||
// Store the intermediate_accum
|
||||
for (int v = 0; v < kVectorSize; v++) {
|
||||
int accum_col = v_b * kVectorSize + v;
|
||||
int64_t output_col = accum_col + n;
|
||||
if (row < M && output_col < N) {
|
||||
acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - General Tensor-Tensor contraction reference kernel
|
||||
@ -188,6 +530,11 @@ void gett_mainloop(
|
||||
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
||||
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
||||
|
||||
|
||||
using ElementSFA = typename ElementTraits<typename MainloopParams::EngineSfA::value_type>::type;
|
||||
using ElementSFB = typename ElementTraits<typename MainloopParams::EngineSfB::value_type>::type;
|
||||
|
||||
|
||||
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
|
||||
RingOp fma_op;
|
||||
|
||||
@ -207,6 +554,14 @@ void gett_mainloop(
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
||||
|
||||
|
||||
if constexpr (not cute::is_same_v<ElementSFA, ElementA>){
|
||||
// Load SFA
|
||||
auto sfa = static_cast<ElementAccumulator>(mainloop_params.SfA(m + m_b, k, l));
|
||||
a_frag[m_b] *= sfa;
|
||||
}
|
||||
|
||||
|
||||
if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
|
||||
a_frag[m_b] = conj(a_frag[m_b]);
|
||||
}
|
||||
@ -222,6 +577,14 @@ void gett_mainloop(
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
||||
|
||||
|
||||
if constexpr (not cute::is_same_v<ElementSFB, ElementB>){
|
||||
// Load SFB
|
||||
auto sfb = static_cast<ElementAccumulator>(mainloop_params.SfB(n + n_b, k, l));
|
||||
b_frag[n_b] *= sfb;
|
||||
}
|
||||
|
||||
|
||||
if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
|
||||
b_frag[n_b] = conj(b_frag[n_b]);
|
||||
}
|
||||
@ -259,6 +622,7 @@ void gett_epilogue(
|
||||
using ElementCompute = typename EpilogueParams::ElementCompute;
|
||||
using ElementC = typename EpilogueParams::TensorC::value_type;
|
||||
using ElementD = typename EpilogueParams::TensorD::value_type;
|
||||
using ElementSfD = typename EpilogueParams::TensorSFD::value_type;
|
||||
using ElementAux = typename EpilogueParams::TensorAux::value_type;
|
||||
using ElementBias = typename EpilogueParams::VectorBias::value_type;
|
||||
using ElementScalar = typename EpilogueParams::ElementScalar;
|
||||
@ -267,6 +631,8 @@ void gett_epilogue(
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
|
||||
constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy;
|
||||
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
@ -412,6 +778,17 @@ void gett_epilogue(
|
||||
}
|
||||
}
|
||||
} // m_b
|
||||
|
||||
if constexpr (
|
||||
SfGenStrategy == SfStrategy::SfDGen
|
||||
) {
|
||||
// 1d scale factor generation
|
||||
constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{};
|
||||
if (epilogue_params.SfD.data() != nullptr) {
|
||||
compute_1d_scaling_factor_and_quantized_output<kVectorSize>(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum);
|
||||
}
|
||||
}
|
||||
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
|
||||
Reference in New Issue
Block a user