cutlass 3.9 update (#2255)
* cutlass 3.9 update * rebase * fixes out of shared memory for blockwise Blackwell * doc format * fix issue 2253 * disable host ref by default * fix sm120 smem capacity --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -233,6 +233,10 @@ cutlass_add_cutlass_library(
|
||||
src/reference/gemm_f6_f8_f32.cu
|
||||
src/reference/gemm_f8_f4_f32.cu
|
||||
src/reference/gemm_f8_f6_f32.cu
|
||||
|
||||
src/reference/blockwise_gemm_fp8_fp16out.cu
|
||||
src/reference/blockwise_gemm_fp8_fp32out.cu
|
||||
src/reference/blockwise_gemm_fp8_bf16out.cu
|
||||
|
||||
src/reference/gemm_s8_s8_s32.cu
|
||||
src/reference/gemm_u8_u8_s32.cu
|
||||
|
||||
@ -313,10 +313,16 @@ struct BlockScaleDescription {
|
||||
TensorDescription SFD;
|
||||
|
||||
/// Describes the input ScaleFactor VectorSize
|
||||
int SFVecSize;
|
||||
int SFMVecSize;
|
||||
int SFNVecSize;
|
||||
int SFKVecSize;
|
||||
|
||||
/// Describes the Output ScaleFactor VectorSize
|
||||
int EpilogueSFVecSize;
|
||||
|
||||
/// Describes the underlying kind of scaling:
|
||||
/// Tensor Core supported (BlockScaled) or manual scaling (Blockwise)
|
||||
OperationKind kind;
|
||||
};
|
||||
|
||||
struct GroupedGemmDescription : public OperationDescription {
|
||||
@ -418,6 +424,96 @@ struct BlockScaledGemmDescription : public OperationDescription {
|
||||
transform_B(transform_B) {}
|
||||
};
|
||||
|
||||
/// Description of all GEMM computations
|
||||
struct BlockwiseGemmDescription : public OperationDescription {
|
||||
|
||||
/// Indicates the kind of GEMM performed
|
||||
GemmKind gemm_kind;
|
||||
|
||||
/// Describes the A operand
|
||||
TensorDescription A;
|
||||
|
||||
/// Describes the B operand
|
||||
TensorDescription B;
|
||||
|
||||
/// Describes the source matrix
|
||||
TensorDescription C;
|
||||
|
||||
/// Describes the destination matrix
|
||||
TensorDescription D;
|
||||
|
||||
/// Describes the SFA operand
|
||||
TensorDescription SFA;
|
||||
|
||||
/// Describes the SFB operand
|
||||
TensorDescription SFB;
|
||||
|
||||
/// Describes the data type of the scalars passed to the epilogue
|
||||
NumericTypeID element_epilogue;
|
||||
|
||||
/// Describes the structure of parallel reductions
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
/// Transformation on A operand
|
||||
ComplexTransform transform_A;
|
||||
|
||||
/// Transformation on B operand
|
||||
ComplexTransform transform_B;
|
||||
|
||||
/// Describes the input ScaleFactor VectorSize
|
||||
int SFMVecSize;
|
||||
int SFNVecSize;
|
||||
int SFKVecSize;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
BlockwiseGemmDescription(
|
||||
GemmKind gemm_kind = GemmKind::kGemm,
|
||||
TensorDescription const& A = TensorDescription(),
|
||||
TensorDescription const& B = TensorDescription(),
|
||||
TensorDescription const& C = TensorDescription(),
|
||||
TensorDescription const& D = TensorDescription(),
|
||||
NumericTypeID element_epilogue = NumericTypeID::kInvalid,
|
||||
SplitKMode split_k_mode = SplitKMode::kNone,
|
||||
ComplexTransform transform_A = ComplexTransform::kNone,
|
||||
ComplexTransform transform_B = ComplexTransform::kNone
|
||||
):
|
||||
gemm_kind(gemm_kind),
|
||||
A(A),
|
||||
B(B),
|
||||
C(C),
|
||||
D(D),
|
||||
element_epilogue(element_epilogue),
|
||||
split_k_mode(split_k_mode),
|
||||
transform_A(transform_A),
|
||||
transform_B(transform_B) {}
|
||||
|
||||
BlockwiseGemmDescription(
|
||||
OperationDescription op_desc,
|
||||
GemmKind gemm_kind,
|
||||
TensorDescription const& A,
|
||||
TensorDescription const& B,
|
||||
TensorDescription const& C,
|
||||
TensorDescription const& D,
|
||||
NumericTypeID element_epilogue,
|
||||
SplitKMode split_k_mode,
|
||||
ComplexTransform transform_A,
|
||||
ComplexTransform transform_B
|
||||
):
|
||||
OperationDescription(op_desc),
|
||||
gemm_kind(gemm_kind),
|
||||
A(A),
|
||||
B(B),
|
||||
C(C),
|
||||
D(D),
|
||||
element_epilogue(element_epilogue),
|
||||
split_k_mode(split_k_mode),
|
||||
transform_A(transform_A),
|
||||
transform_B(transform_B) {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Description for structured sparse GEMMs.
|
||||
|
||||
@ -121,6 +121,13 @@ public:
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const = 0;
|
||||
|
||||
// Set arguments that should only be set once before verifying or profiling the kernel.
// This should encompass any expensive operations that don't vary from run to run
// (e.g., max_active_clusters).
// The default implementation is a no-op so existing operations need not override it.
virtual Status initialize_with_arguments(void* arguments_ptr) const {
  return Status::kSuccess;
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -389,6 +396,56 @@ struct BlockScaledGemmArguments {
|
||||
bool use_pdl{false};
|
||||
};
|
||||
|
||||
/// Blockwise GEMM
//
// OperationKind: kBlockwiseGemm
// GemmKind: Universal
//
// Plain-aggregate argument bundle passed through the library's type-erased
// Operation interface. All members are brace-initialized so a
// default-constructed instance is fully zeroed/null.
struct BlockwiseGemmArguments {
  // NOTE: these are replicated for 3.0 interfaces
  gemm::GemmCoord problem_size{};
  gemm::GemmCoord cluster_shape{};
  gemm::GemmCoord cluster_shape_fallback{};
  int batch_count{1};

  // Device pointers for operands, scale factors, source, and destination.
  void const *A{nullptr};
  void const *B{nullptr};
  void const *SFA{nullptr};
  void const *SFB{nullptr};
  void const *C{nullptr};
  void *D{nullptr};

  // alpha/beta may be host or device pointers; pointer_mode disambiguates.
  void const *alpha{nullptr};
  void const *beta{nullptr};
  ScalarPointerMode pointer_mode{};

  // NOTE: these are replicated for 3.0 interfaces
  int64_t lda{0};
  int64_t ldb{0};
  int64_t ldc{0};
  int64_t ldd{0};

  int64_t batch_stride_A{0};
  int64_t batch_stride_B{0};
  int64_t batch_stride_C{0};
  int64_t batch_stride_D{0};

  // Requested scale-factor granularities; zero means "use the kernel's native
  // granularity" (see can_implement in the blockwise operation).
  int sf_m_vec_size{0};
  int sf_n_vec_size{0};
  int sf_k_vec_size{0};

  // Needed for some 3.x kernels
  int sm_count{0};
  library::RasterOrder raster_order{};
  int swizzle_size{1};
  int split_k_slices{1};

  // Runtime datatypes (kStatic unless the kernel supports runtime f8f6f4).
  library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic};
  library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic};

  bool use_pdl{false};
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -521,6 +578,8 @@ struct GemmGroupedArguments {
|
||||
|
||||
// these should really be in the configuration but staying consistent with GEMM
|
||||
int sm_count{0};
|
||||
int max_active_clusters{0};
|
||||
|
||||
// The user is responsible for allocating storage for problem sizes.
|
||||
// Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we
|
||||
// unfortunately need to have both options in this struct, and the
|
||||
@ -536,6 +595,12 @@ struct GroupedGemmBlockScaledArguments : GemmGroupedArguments {
|
||||
void* norm_constant{nullptr};
|
||||
};
|
||||
|
||||
// Arguments for grouped blockwise-scaled GEMM: extends the grouped GEMM
// arguments with per-operand scale-factor pointers.
struct GroupedGemmBlockwiseArguments : GemmGroupedArguments {
  void* SFA{nullptr};  // scale factors for operand A
  void* SFB{nullptr};  // scale factors for operand B
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// OperationKind: kSparseGemm
|
||||
|
||||
@ -427,6 +427,183 @@ using BlockScaledGemmOperationFunctionalMap = std::unordered_map<
|
||||
BlockScaledGemmFunctionalKeyHasher
|
||||
>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Data Structures for Blockwise Gemm Functional Maps
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tuple uniquely identifying Gemm functional behavior
|
||||
struct BlockwiseGemmFunctionalKey {
|
||||
|
||||
Provider provider;
|
||||
GemmKind gemm_kind;
|
||||
OperationKind kind;
|
||||
NumericTypeID element_compute;
|
||||
NumericTypeID element_scalar;
|
||||
NumericTypeID element_A;
|
||||
LayoutTypeID layout_A;
|
||||
NumericTypeID element_SFA;
|
||||
NumericTypeID element_B;
|
||||
LayoutTypeID layout_B;
|
||||
NumericTypeID element_SFB;
|
||||
NumericTypeID element_C;
|
||||
LayoutTypeID layout_C;
|
||||
NumericTypeID element_D;
|
||||
LayoutTypeID layout_D;
|
||||
int SFMVecSize;
|
||||
int SFNVecSize;
|
||||
int SFKVecSize;
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
inline
|
||||
BlockwiseGemmFunctionalKey(
|
||||
Provider provider,
|
||||
GemmKind gemm_kind = GemmKind::kGemm,
|
||||
OperationKind kind = OperationKind::kBlockwiseGemm,
|
||||
NumericTypeID element_compute = NumericTypeID::kF32,
|
||||
NumericTypeID element_scalar = NumericTypeID::kF32,
|
||||
NumericTypeID element_A = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_A = LayoutTypeID::kColumnMajor,
|
||||
NumericTypeID element_SFA = NumericTypeID::kF16,
|
||||
NumericTypeID element_B = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_B = LayoutTypeID::kColumnMajor,
|
||||
NumericTypeID element_SFB = NumericTypeID::kF16,
|
||||
NumericTypeID element_C = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_C = LayoutTypeID::kColumnMajor,
|
||||
NumericTypeID element_D = NumericTypeID::kF16,
|
||||
LayoutTypeID layout_D = LayoutTypeID::kColumnMajor,
|
||||
int sfm_vec_size = 32,
|
||||
int sfn_vec_size = 32,
|
||||
int sfk_vec_size = 32
|
||||
):
|
||||
provider(provider),
|
||||
gemm_kind(gemm_kind),
|
||||
kind(kind),
|
||||
element_compute(element_compute),
|
||||
element_scalar(element_scalar),
|
||||
element_A(element_A),
|
||||
layout_A(layout_A),
|
||||
element_SFA(element_SFA),
|
||||
element_B(element_B),
|
||||
layout_B(layout_B),
|
||||
element_SFB(element_SFB),
|
||||
element_C(element_C),
|
||||
layout_C(layout_C),
|
||||
element_D(element_D),
|
||||
layout_D(layout_D),
|
||||
SFMVecSize(sfm_vec_size),
|
||||
SFNVecSize(sfn_vec_size),
|
||||
SFKVecSize(sfk_vec_size)
|
||||
{ }
|
||||
|
||||
inline
|
||||
bool operator==(BlockwiseGemmFunctionalKey const &rhs) const {
|
||||
return
|
||||
(provider == rhs.provider) &&
|
||||
(gemm_kind == rhs.gemm_kind) &&
|
||||
(kind == rhs.kind) &&
|
||||
(element_compute == rhs.element_compute) &&
|
||||
(element_scalar == rhs.element_scalar) &&
|
||||
(element_A == rhs.element_A) &&
|
||||
(layout_A == rhs.layout_A) &&
|
||||
(element_SFA == rhs.element_SFA) &&
|
||||
(element_B == rhs.element_B) &&
|
||||
(layout_B == rhs.layout_B) &&
|
||||
(element_SFB == rhs.element_SFB) &&
|
||||
(element_C == rhs.element_C) &&
|
||||
(layout_C == rhs.layout_C) &&
|
||||
(element_D == rhs.element_D) &&
|
||||
(layout_D == rhs.layout_D) &&
|
||||
(SFMVecSize == rhs.SFMVecSize) &&
|
||||
(SFNVecSize == rhs.SFNVecSize) &&
|
||||
(SFKVecSize == rhs.SFKVecSize);
|
||||
}
|
||||
|
||||
inline
|
||||
bool operator!=(BlockwiseGemmFunctionalKey const &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::library::BlockwiseGemmFunctionalKey const &k) {
|
||||
|
||||
out << "{\n"
|
||||
<< " provider: " << to_string(k.provider) << "\n"
|
||||
<< " gemm_kind: " << to_string(k.gemm_kind) << "\n"
|
||||
<< " kind: " << to_string(k.kind) << "\n"
|
||||
<< " element_compute: " << to_string(k.element_compute) << "\n"
|
||||
<< " element_scalar: " << to_string(k.element_scalar) << "\n"
|
||||
<< " element_A: " << to_string(k.element_A) << "\n"
|
||||
<< " layout_A: " << to_string(k.layout_A) << "\n"
|
||||
<< " element_SFA: " << to_string(k.element_SFA) << "\n"
|
||||
<< " element_B: " << to_string(k.element_B) << "\n"
|
||||
<< " layout_B: " << to_string(k.layout_B) << "\n"
|
||||
<< " element_SFB: " << to_string(k.element_SFB) << "\n"
|
||||
<< " element_C: " << to_string(k.element_C) << "\n"
|
||||
<< " layout_C: " << to_string(k.layout_C) << "\n"
|
||||
<< " element_D: " << to_string(k.element_D) << "\n"
|
||||
<< " layout_D: " << to_string(k.layout_D) << "\n"
|
||||
<< " SFMVecSize: " << k.SFMVecSize << "\n"
|
||||
<< " SFNVecSize: " << k.SFNVecSize << "\n"
|
||||
<< " SFKVecSize: " << k.SFKVecSize << "\n"
|
||||
<< "}";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Hash function for BlockwiseGemmFunctionalKeyHasher
|
||||
struct BlockwiseGemmFunctionalKeyHasher {
|
||||
using IntHash = std::hash<int>;
|
||||
|
||||
inline
|
||||
static size_t rotl(size_t key, int shl) {
|
||||
return (key << shl) | (key >> (sizeof(key)*8u - static_cast<size_t>(shl)));
|
||||
}
|
||||
|
||||
inline
|
||||
size_t operator()(BlockwiseGemmFunctionalKey const &key) const {
|
||||
IntHash hash;
|
||||
|
||||
return
|
||||
rotl(hash(int(key.provider)), 1) ^
|
||||
rotl(hash(int(key.gemm_kind)), 2) ^
|
||||
rotl(hash(int(key.kind)), 3) ^
|
||||
rotl(hash(int(key.element_compute)), 4) ^
|
||||
rotl(hash(int(key.element_scalar)), 5) ^
|
||||
rotl(hash(int(key.element_A)), 6) ^
|
||||
rotl(hash(int(key.layout_A)), 7) ^
|
||||
rotl(hash(int(key.element_SFA)), 8) ^
|
||||
rotl(hash(int(key.element_B)), 9) ^
|
||||
rotl(hash(int(key.layout_B)), 10) ^
|
||||
rotl(hash(int(key.element_SFB)), 11) ^
|
||||
rotl(hash(int(key.element_C)), 12) ^
|
||||
rotl(hash(int(key.layout_C)), 13) ^
|
||||
rotl(hash(int(key.element_D)), 14) ^
|
||||
rotl(hash(int(key.layout_D)), 15) ^
|
||||
rotl(hash(int(key.SFMVecSize)), 16) ^
|
||||
rotl(hash(int(key.SFNVecSize)), 17) ^
|
||||
rotl(hash(int(key.SFKVecSize)), 18)
|
||||
;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Maps a BlockwiseGemmFunctionalKey onto a vector of Operation * objects
/// expected to be of kind kBlockwiseGemm
using BlockwiseGemmOperationFunctionalMap = std::unordered_map<
  BlockwiseGemmFunctionalKey,
  GemmOperationVectorMap,
  BlockwiseGemmFunctionalKeyHasher
>;
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Data Structures for Conv Functional Maps
|
||||
@ -697,6 +874,9 @@ public:
|
||||
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
|
||||
BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations;
|
||||
|
||||
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
|
||||
BlockwiseGemmOperationFunctionalMap blockwise_gemm_operations;
|
||||
|
||||
/// Map of all operations of type kConv2d
|
||||
// provider (kCUTLASS, kReferenceHost, kReferenceDevice)
|
||||
ConvOperationFunctionalMap conv2d_operations;
|
||||
|
||||
@ -143,6 +143,7 @@ enum class Provider {
|
||||
enum class OperationKind {
|
||||
kGemm,
|
||||
kBlockScaledGemm,
|
||||
kBlockwiseGemm,
|
||||
kRankK,
|
||||
kRank2K,
|
||||
kTrmm,
|
||||
|
||||
429
tools/library/src/blockwise_gemm_operation_3x.hpp
Normal file
429
tools/library/src/blockwise_gemm_operation_3x.hpp
Normal file
@ -0,0 +1,429 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Defines operations for all GEMM operation kinds in CUTLASS Library.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "library_internal.h"
|
||||
#include "gemm_operation_3x.hpp"
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Operator_>
|
||||
class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
|
||||
public:
|
||||
using Operator = Operator_;
|
||||
using OperatorArguments = typename Operator::Arguments;
|
||||
using ElementA = typename Operator::CollectiveMainloop::ElementA;
|
||||
using ElementSFA = typename Operator::ElementAccumulator;
|
||||
using LayoutA = typename Operator::LayoutA;
|
||||
using ElementB = typename Operator::CollectiveMainloop::ElementB;
|
||||
using ElementSFB = typename Operator::ElementAccumulator;
|
||||
using LayoutB = typename Operator::LayoutB;
|
||||
using ElementC = typename Operator::ElementC;
|
||||
using LayoutC = typename Operator::LayoutC;
|
||||
using ElementD = typename Operator::ElementD;
|
||||
using LayoutD = typename Operator::LayoutD;
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using TiledMma = typename Operator::CollectiveMainloop::TiledMma;
|
||||
|
||||
using CollectiveMainloop = typename Operator::CollectiveMainloop;
|
||||
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
|
||||
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
|
||||
|
||||
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
|
||||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
|
||||
"ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
|
||||
|
||||
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
|
||||
|
||||
private:
|
||||
BlockwiseGemmDescription description_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
BlockwiseGemmUniversal3xOperation(char const *name = "unknown_gemm"):
|
||||
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
|
||||
description_.kind = OperationKind::kBlockwiseGemm;
|
||||
description_.SFA.element = NumericTypeMap<ElementSFA>::kId;
|
||||
description_.SFA.layout = size<0,1>(typename CollectiveMainloop::LayoutSFA{}.stride()) == 1 ?
|
||||
LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor;
|
||||
description_.SFA.alignment = CollectiveMainloop::AlignmentSFA;
|
||||
description_.SFA.log_extent_range = 32;
|
||||
description_.SFA.log_stride_range = 32;
|
||||
|
||||
description_.SFB.element = NumericTypeMap<ElementSFB>::kId;
|
||||
description_.SFB.layout = size<0,1>(typename CollectiveMainloop::LayoutSFB{}.stride()) == 1 ?
|
||||
LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor;
|
||||
description_.SFB.alignment = CollectiveMainloop::AlignmentSFA;
|
||||
description_.SFB.log_extent_range = 32;
|
||||
description_.SFB.log_stride_range = 32;
|
||||
|
||||
description_.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM;
|
||||
description_.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN;
|
||||
description_.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK;
|
||||
|
||||
description_.name = name;
|
||||
description_.provider = Provider::kCUTLASS;
|
||||
description_.gemm_kind = GemmKind::kUniversal;
|
||||
|
||||
description_.tile_description.threadblock_shape = make_Coord(
|
||||
Operator::ThreadblockShape::kM,
|
||||
Operator::ThreadblockShape::kN,
|
||||
Operator::ThreadblockShape::kK);
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
|
||||
description_.tile_description.cluster_shape = make_Coord(
|
||||
Operator::ClusterShape::kM,
|
||||
Operator::ClusterShape::kN,
|
||||
Operator::ClusterShape::kK);
|
||||
}
|
||||
|
||||
description_.tile_description.threadblock_stages = Operator::kStages;
|
||||
|
||||
description_.tile_description.warp_count = make_Coord(
|
||||
Operator::WarpCount::kM,
|
||||
Operator::WarpCount::kN,
|
||||
Operator::WarpCount::kK);
|
||||
|
||||
description_.tile_description.math_instruction.instruction_shape = make_Coord(
|
||||
Operator::InstructionShape::kM,
|
||||
Operator::InstructionShape::kN,
|
||||
Operator::InstructionShape::kK);
|
||||
|
||||
description_.tile_description.math_instruction.element_accumulator =
|
||||
NumericTypeMap<ElementAccumulator>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.opcode_class =
|
||||
OpcodeClassMap<typename Operator::OperatorClass>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.math_operation =
|
||||
MathOperationMap<typename Operator::MathOperator>::kId;
|
||||
|
||||
description_.tile_description.minimum_compute_capability =
|
||||
ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMin;
|
||||
|
||||
description_.tile_description.maximum_compute_capability =
|
||||
ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMax;
|
||||
|
||||
description_.A = make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
|
||||
description_.B = make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
|
||||
description_.C = make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
|
||||
description_.D = make_TensorDescription<ElementD, LayoutD>(Operator::kAlignmentD);
|
||||
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
||||
|
||||
description_.split_k_mode = SplitKMode::kNone;
|
||||
}
|
||||
|
||||
/// Returns the description of the GEMM operation
// (as the type-erased OperationDescription base reference)
virtual OperationDescription const & description() const {
  return description_;
}
|
||||
|
||||
/// Returns the description of the GEMM operation
// Typed accessor: same object as description(), but exposed as the concrete
// BlockwiseGemmDescription so callers can read blockwise-specific fields.
BlockwiseGemmDescription const& get_gemm_description() const {
  return description_;
}
|
||||
|
||||
protected:
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
// Only the GEMM mode is taken from the configuration; everything else is
// deferred (see the NOTE below).
static Status construct_arguments_(
  OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
  // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides
  // Do nothing here and construct kernel arguments in update_arguments_ instead
  // We also cannot construct TMA descriptors without all the arguments available

  operator_args.mode = configuration->mode;
  return Status::kSuccess;
}
|
||||
|
||||
// Primary template: selected when FusionArgs has no `.alpha` member (i.e. a
// custom EVT epilogue); see the partial specialization below for the default
// linear-combination case.
template<class FusionArgs, class = void>
struct UpdateFusionArgs {
  static Status update_(FusionArgs const& fusion_args, BlockwiseGemmArguments const &arguments) {
    // If a custom EVT is instantiated then it is the users's responsibility
    // to ensure alpha and beta are updated appropriately
    return Status::kSuccess;
  }
};
|
||||
|
||||
template<class FusionArgs>
|
||||
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
|
||||
static Status update_(FusionArgs& fusion_args, BlockwiseGemmArguments const &arguments) {
|
||||
if (arguments.pointer_mode == ScalarPointerMode::kHost) {
|
||||
fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
|
||||
fusion_args.beta = *static_cast<ElementCompute const *>(arguments.beta);
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else if (arguments.pointer_mode == ScalarPointerMode::kDevice) {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = static_cast<ElementCompute const *>(arguments.alpha);
|
||||
fusion_args.beta_ptr = static_cast<ElementCompute const *>(arguments.beta);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status update_arguments_(
|
||||
OperatorArguments &operator_args,
|
||||
BlockwiseGemmArguments const *arguments) {
|
||||
Status status = Status::kSuccess;
|
||||
|
||||
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
|
||||
operator_args.epilogue.thread, *arguments);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
operator_args.problem_shape = cute::make_shape(
|
||||
arguments->problem_size.m(),
|
||||
arguments->problem_size.n(),
|
||||
arguments->problem_size.k(),
|
||||
arguments->batch_count);
|
||||
|
||||
// update arguments
|
||||
|
||||
if constexpr (IsRuntimeDataType) {
|
||||
using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);
|
||||
|
||||
std::unordered_map<RuntimeDatatype, cute::UMMA::MXF8F6F4Format> mapping = {
|
||||
{RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
|
||||
{RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
|
||||
{RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
|
||||
{RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
|
||||
};
|
||||
|
||||
auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
|
||||
auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);
|
||||
|
||||
if (iter_runtime_a != mapping.end()) {
|
||||
operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
|
||||
} else {
|
||||
assert("invalid runtime argument for datatype A!");
|
||||
}
|
||||
|
||||
if (iter_runtime_b != mapping.end()) {
|
||||
operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
|
||||
} else {
|
||||
assert("invalid runtime argument for datatype B!");
|
||||
}
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
|
||||
operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
|
||||
}
|
||||
operator_args.mainloop.ptr_SFA = static_cast<ElementSFA const *>(arguments->SFA);
|
||||
operator_args.mainloop.ptr_SFB = static_cast<ElementSFB const *>(arguments->SFB);
|
||||
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
|
||||
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
|
||||
|
||||
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
|
||||
arguments->lda, arguments->batch_stride_A);
|
||||
operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
|
||||
arguments->ldb, arguments->batch_stride_B);
|
||||
operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
|
||||
arguments->ldc, arguments->batch_stride_C);
|
||||
operator_args.epilogue.dD = operator_args.epilogue.dC;
|
||||
|
||||
operator_args.mainloop.layout_SFA = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(operator_args.problem_shape);
|
||||
operator_args.mainloop.layout_SFB = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(operator_args.problem_shape);
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
operator_args.hw_info.sm_count = arguments->sm_count;
|
||||
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
|
||||
operator_args.scheduler.max_swizzle_size = arguments->swizzle_size;
|
||||
}
|
||||
|
||||
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.raster_order)>) {
|
||||
using Enum_t = decltype(operator_args.scheduler.raster_order);
|
||||
switch (arguments->raster_order) {
|
||||
case RasterOrder::kAlongN:
|
||||
operator_args.scheduler.raster_order = Enum_t::AlongN;
|
||||
break;
|
||||
case RasterOrder::kAlongM:
|
||||
operator_args.scheduler.raster_order = Enum_t::AlongM;
|
||||
break;
|
||||
default:
|
||||
operator_args.scheduler.raster_order = Enum_t::Heuristic;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
|
||||
operator_args.scheduler.splits = arguments->split_k_slices;
|
||||
}
|
||||
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
|
||||
operator_args.hw_info.cluster_shape = dim3(
|
||||
arguments->cluster_shape.m(),
|
||||
arguments->cluster_shape.n(),
|
||||
arguments->cluster_shape.k());
|
||||
operator_args.hw_info.cluster_shape_fallback = dim3(
|
||||
arguments->cluster_shape_fallback.m(),
|
||||
arguments->cluster_shape_fallback.n(),
|
||||
arguments->cluster_shape_fallback.k());
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Returns success if the operation can proceed
|
||||
/// Returns success if the operation can proceed with the given
/// configuration and blockwise-GEMM arguments.
Status can_implement(
    void const *configuration_ptr, void const *arguments_ptr) const override {

  auto const *configuration =
      static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
  auto const *arguments =
      static_cast<BlockwiseGemmArguments const *>(arguments_ptr);

  // A requested vector size of 0 means "unspecified"; any other value
  // must match the scale granularity this kernel was instantiated with.
  auto vec_size_ok = [](int requested, int supported) {
    return requested == 0 || requested == supported;
  };
  if (!vec_size_ok(arguments->sf_m_vec_size, description_.SFMVecSize) ||
      !vec_size_ok(arguments->sf_n_vec_size, description_.SFNVecSize) ||
      !vec_size_ok(arguments->sf_k_vec_size, description_.SFKVecSize)) {
    return Status::kErrorInvalidProblem;
  }

  OperatorArguments args;
  Status status = update_arguments_(args, arguments);
  if (status != Status::kSuccess) {
    return status;
  }

  // can_implement rules may need access to problem shape
  args.problem_shape = cute::make_shape(
      configuration->problem_size.m(),
      configuration->problem_size.n(),
      configuration->problem_size.k(),
      configuration->batch_count);

  return Operator::can_implement(args);
}
|
||||
|
||||
/// Gets the host-side workspace
|
||||
/// Gets the host-side workspace size in bytes. The host workspace holds a
/// single device-level Operator object (placement-constructed in
/// initialize()); the configuration parameter is unused.
uint64_t get_host_workspace_size(void const *configuration) const override {
  return sizeof(Operator);
}
|
||||
|
||||
/// Gets the device-side workspace
|
||||
uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr,void const *arguments_ptr) const override {
|
||||
|
||||
OperatorArguments args;
|
||||
auto status = update_arguments_(
|
||||
args, static_cast<BlockwiseGemmArguments const *>(arguments_ptr));
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t size = Operator::get_workspace_size(args);
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Initializes the workspace
|
||||
/// Initializes the host workspace by placement-constructing the
/// device-level Operator into caller-provided storage. No device-side
/// state is prepared here; run() rebuilds kernel arguments per call.
Status initialize(
    void const *configuration_ptr,
    void *host_workspace,
    void *device_workspace,
    cudaStream_t stream = nullptr) const override {
  new (host_workspace) Operator;
  return Status::kSuccess;
}
|
||||
|
||||
/// No-op hook for profiler-managed workspaces: this operation performs no
/// per-workspace initialization, so all parameters are ignored and success
/// is returned unconditionally.
Status initialize_with_profiler_workspace(
    void const *configuration,
    void *host_workspace,
    void *device_workspace,
    uint8_t **profiler_workspaces,
    int problem_count_from_profiler,
    cudaStream_t stream = nullptr) {
  return Status::kSuccess;
}
|
||||
|
||||
/// Runs the kernel
|
||||
/// Runs the kernel: translates the type-erased blockwise arguments into
/// operator arguments and launches through the device-level operator.
Status run(
    void const *arguments_ptr,
    void *host_workspace,
    void *device_workspace = nullptr,
    cudaStream_t stream = nullptr) const override {

  auto const *blockwise_args =
      static_cast<BlockwiseGemmArguments const *>(arguments_ptr);

  OperatorArguments args;
  if (Status status = update_arguments_(args, blockwise_args);
      status != Status::kSuccess) {
    return status;
  }

  Operator *op = static_cast<Operator *>(host_workspace);
  // We need to call initialize() since we have to rebuild TMA desc for every new set of args
  return op->run(args, device_workspace, stream, nullptr, blockwise_args->use_pdl);
}
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::library
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -72,19 +72,6 @@ public:
|
||||
|
||||
this->description_.gemm = GemmOperation3xBase<Operator_>::description_;
|
||||
this->description_.tile_description = this->description_.gemm.tile_description;
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster_dims(
|
||||
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
|
||||
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
|
||||
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
|
||||
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
|
||||
cluster_dims,
|
||||
threads_per_block,
|
||||
kernel_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
@ -102,7 +89,6 @@ public:
|
||||
|
||||
protected:
|
||||
library::GroupedGemmDescription description_;
|
||||
int max_active_clusters;
|
||||
|
||||
Status initialize_strides(GemmGroupedConfiguration const& config) const {
|
||||
auto const num_groups = config.problem_count;
|
||||
@ -182,7 +168,7 @@ protected:
|
||||
|
||||
operator_args.hw_info.sm_count = arguments.sm_count;
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
|
||||
operator_args.hw_info.max_active_clusters = max_active_clusters;
|
||||
operator_args.hw_info.max_active_clusters = arguments.max_active_clusters;
|
||||
}
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
|
||||
operator_args.hw_info.cluster_shape =
|
||||
@ -343,6 +329,47 @@ public:
|
||||
status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl);
|
||||
return status;
|
||||
}
|
||||
|
||||
|
||||
// Set arguments that should only be set once before verifying or profiling the kernel.
|
||||
// This should encompass any expensive operations that don't vary from run to run
|
||||
// (e.g., max_active_clusters).
|
||||
Status initialize_with_arguments(void* arguments_ptr) const override {
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability < 90) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
GemmGroupedArguments* args = static_cast<GemmGroupedArguments*>(arguments_ptr);
|
||||
|
||||
dim3 cluster_dims;
|
||||
if constexpr (cute::is_static_v<typename Operator::GemmKernel::ClusterShape>) {
|
||||
cluster_dims = dim3(
|
||||
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<2>(typename Operator::GemmKernel::ClusterShape{})
|
||||
);
|
||||
}
|
||||
else {
|
||||
cluster_dims = dim3(
|
||||
args->cluster_shape.m(),
|
||||
args->cluster_shape.n(),
|
||||
args->cluster_shape.k()
|
||||
);
|
||||
}
|
||||
|
||||
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
|
||||
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
|
||||
args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
|
||||
cluster_dims,
|
||||
threads_per_block,
|
||||
kernel_ptr);
|
||||
|
||||
if (args->max_active_clusters == 0) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Operator_>
|
||||
@ -375,6 +402,7 @@ public:
|
||||
: GroupedGemmOperation3xBase<Operator_>(name) {
|
||||
|
||||
BlockScaleDescription block_scaled_desc{};
|
||||
block_scaled_desc.kind = OperationKind::kBlockScaledGemm;
|
||||
block_scaled_desc.SFA.element = NumericTypeMap<ElementSFA>::kId;
|
||||
block_scaled_desc.SFA.layout = LayoutTypeID::kRowMajor;
|
||||
block_scaled_desc.SFA.alignment = 128;
|
||||
@ -387,7 +415,9 @@ public:
|
||||
block_scaled_desc.SFB.log_extent_range = 32;
|
||||
block_scaled_desc.SFB.log_stride_range = 32;
|
||||
|
||||
block_scaled_desc.SFVecSize = SFVecSize;
|
||||
block_scaled_desc.SFMVecSize = 1;
|
||||
block_scaled_desc.SFNVecSize = 1;
|
||||
block_scaled_desc.SFKVecSize = SFVecSize;
|
||||
|
||||
block_scaled_desc.SFD = make_TensorDescription<ElementSFD, LayoutSFD>(128);
|
||||
block_scaled_desc.EpilogueSFVecSize = SFD_VectorSize;
|
||||
@ -555,4 +585,206 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Operator_>
|
||||
class GroupedBlockwiseGemmUniversal3xOperation : public GroupedGemmOperation3xBase<Operator_> {
|
||||
public:
|
||||
using Operator = Operator_;
|
||||
using OperatorArguments = typename Operator::Arguments;
|
||||
using ElementD = typename Operator::ElementD;
|
||||
using LayoutD = typename Operator::LayoutD;
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using CollectiveMainloop = typename Operator::CollectiveMainloop;
|
||||
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
|
||||
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
using ElementSFA = typename Operator::ElementAccumulator;
|
||||
using ElementSFB = typename Operator::ElementAccumulator;
|
||||
|
||||
using TiledMma = typename Operator::CollectiveMainloop::TiledMma;
|
||||
|
||||
GroupedBlockwiseGemmUniversal3xOperation(char const* name = "unknown_gemm")
|
||||
: GroupedGemmOperation3xBase<Operator_>(name) {
|
||||
|
||||
BlockScaleDescription blockwise_desc{};
|
||||
blockwise_desc.kind = OperationKind::kBlockwiseGemm;
|
||||
blockwise_desc.SFA.element = NumericTypeMap<ElementSFA>::kId;
|
||||
blockwise_desc.SFA.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFA{}.stride()) == 1 ?
|
||||
LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor;
|
||||
blockwise_desc.SFA.alignment = CollectiveMainloop::AlignmentSFA;
|
||||
blockwise_desc.SFA.log_extent_range = 32;
|
||||
blockwise_desc.SFA.log_stride_range = 32;
|
||||
|
||||
blockwise_desc.SFB.element = NumericTypeMap<ElementSFB>::kId;
|
||||
blockwise_desc.SFB.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFB{}.stride()) == 1 ?
|
||||
LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor;
|
||||
blockwise_desc.SFB.alignment = CollectiveMainloop::AlignmentSFA;
|
||||
blockwise_desc.SFB.log_extent_range = 32;
|
||||
blockwise_desc.SFB.log_stride_range = 32;
|
||||
|
||||
blockwise_desc.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM;
|
||||
blockwise_desc.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN;
|
||||
blockwise_desc.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK;
|
||||
|
||||
blockwise_desc.EpilogueSFVecSize = 0;
|
||||
|
||||
this->description_.block_scales = blockwise_desc;
|
||||
}
|
||||
|
||||
~GroupedBlockwiseGemmUniversal3xOperation() override = default;
|
||||
|
||||
mutable CudaBuffer layout_SFA_device;
|
||||
mutable CudaBuffer layout_SFB_device;
|
||||
|
||||
protected:
|
||||
template <class FusionArgs, class = void> struct UpdateFusionArgs {
|
||||
static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) {
|
||||
// If a custom EVT is instantiated then it is the users's responsibility
|
||||
// to ensure alpha and beta are updated appropriately
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
template <class FusionArgs>
|
||||
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
|
||||
static Status
|
||||
update_(FusionArgs& fusion_args, GroupedGemmBlockwiseArguments const& arguments) {
|
||||
return GroupedGemmOperation3xBase<Operator>::update_fusion_args(fusion_args, arguments);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
/// Returns success if the operation can proceed
|
||||
Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr)
|
||||
const override {
|
||||
GroupedGemmBlockwiseArguments const* arguments =
|
||||
static_cast<GroupedGemmBlockwiseArguments const*>(arguments_ptr);
|
||||
OperatorArguments args;
|
||||
auto status = update_arguments_(args, arguments);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = Operator::can_implement(args);
|
||||
return status;
|
||||
}
|
||||
|
||||
Status update_arguments_(
|
||||
OperatorArguments& operator_args,
|
||||
GroupedGemmBlockwiseArguments const* arguments) const {
|
||||
Status status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
|
||||
operator_args.epilogue.thread,
|
||||
*arguments);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
operator_args.mainloop.ptr_SFA =
|
||||
static_cast<const typename Operator::GemmKernel::ElementAccumulator**>(arguments->SFA);
|
||||
operator_args.mainloop.ptr_SFB =
|
||||
static_cast<const typename Operator::GemmKernel::ElementAccumulator**>(arguments->SFB);
|
||||
|
||||
operator_args.mainloop.layout_SFA =
|
||||
static_cast<typename CollectiveMainloop::InternalLayoutSFA*>(this->layout_SFA_device.data());
|
||||
operator_args.mainloop.layout_SFB =
|
||||
static_cast<typename CollectiveMainloop::InternalLayoutSFB*>(this->layout_SFB_device.data());
|
||||
|
||||
return this->update_arguments_base(operator_args, *arguments);
|
||||
}
|
||||
|
||||
uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr)
|
||||
const override {
|
||||
|
||||
OperatorArguments args;
|
||||
auto status =
|
||||
update_arguments_(args, static_cast<GroupedGemmBlockwiseArguments const*>(arguments_ptr));
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t size = Operator::get_workspace_size(args);
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Initializes the workspace
|
||||
/// **** CAUTION ****
|
||||
/// Must be called when lda, ldb, ldc, or ldd change.
|
||||
/// The CUTLASS library stores the operations in a type-
|
||||
/// erased manifest. Therefore, only this class knows
|
||||
/// the type of strideA, strideB, strideC, and strideD.
|
||||
/// Since grouped GEMM needs to allocate storage for
|
||||
/// the strides on device, the concrete type of the stride
|
||||
/// must be known in order to copy in the correct memory
|
||||
/// layout on device.
|
||||
Status initialize(
|
||||
void const* configuration_ptr,
|
||||
void* host_workspace,
|
||||
void* device_workspace,
|
||||
cudaStream_t stream = nullptr) const override {
|
||||
|
||||
auto const& config = *static_cast<GemmGroupedConfiguration const*>(configuration_ptr);
|
||||
auto status = this->initialize_strides(config);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
auto num_groups = config.problem_count;
|
||||
this->layout_SFA_device =
|
||||
CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups);
|
||||
this->layout_SFB_device =
|
||||
CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups);
|
||||
auto layout_SFA_host = std::vector<typename CollectiveMainloop::InternalLayoutSFA>(num_groups);
|
||||
auto layout_SFB_host = std::vector<typename CollectiveMainloop::InternalLayoutSFB>(num_groups);
|
||||
|
||||
for (int group_idx = 0; group_idx < num_groups; group_idx++) {
|
||||
auto const& shape = config.problem_sizes_3x_host[group_idx];
|
||||
auto M = get<0>(shape);
|
||||
auto N = get<1>(shape);
|
||||
auto K = get<2>(shape);
|
||||
|
||||
auto layout_SFA = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
layout_SFA_host[group_idx] = layout_SFA;
|
||||
layout_SFB_host[group_idx] = layout_SFB;
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
this->layout_SFA_device.data(),
|
||||
layout_SFA_host.data(),
|
||||
sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups,
|
||||
cudaMemcpyHostToDevice));
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
this->layout_SFB_device.data(),
|
||||
layout_SFB_host.data(),
|
||||
sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups,
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
Operator* op = new (host_workspace) Operator;
|
||||
return status;
|
||||
}
|
||||
|
||||
/// **** CAUTION ****
|
||||
/// initialize() must be called if lda, ldb, ldc, or ldd change.
|
||||
Status run(
|
||||
void const* arguments_ptr,
|
||||
void* host_workspace,
|
||||
void* device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const override {
|
||||
|
||||
OperatorArguments operator_args;
|
||||
auto const& args = *static_cast<GroupedGemmBlockwiseArguments const*>(arguments_ptr);
|
||||
|
||||
Status status = update_arguments_(operator_args, &args);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator* op = static_cast<Operator*>(host_workspace);
|
||||
status = op->run(operator_args, device_workspace, stream, nullptr);
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace cutlass::library
|
||||
|
||||
@ -86,6 +86,41 @@ void OperationTable::append(Manifest const &manifest) {
|
||||
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
|
||||
}
|
||||
|
||||
if (desc.kind == OperationKind::kBlockwiseGemm) {
|
||||
BlockwiseGemmDescription const &gemm_desc = static_cast<BlockwiseGemmDescription const &>(desc);
|
||||
|
||||
BlockwiseGemmFunctionalKey functional_key(
|
||||
gemm_desc.provider,
|
||||
gemm_desc.gemm_kind,
|
||||
gemm_desc.kind,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
gemm_desc.A.element,
|
||||
gemm_desc.A.layout,
|
||||
gemm_desc.SFA.element,
|
||||
gemm_desc.B.element,
|
||||
gemm_desc.B.layout,
|
||||
gemm_desc.SFB.element,
|
||||
gemm_desc.C.element,
|
||||
gemm_desc.C.layout,
|
||||
gemm_desc.D.element,
|
||||
gemm_desc.D.layout,
|
||||
gemm_desc.SFMVecSize,
|
||||
gemm_desc.SFNVecSize,
|
||||
gemm_desc.SFKVecSize
|
||||
);
|
||||
|
||||
Operation const *op = operation.get();
|
||||
|
||||
int cc = gemm_desc.tile_description.minimum_compute_capability;
|
||||
|
||||
int alignment = std::max(std::max(
|
||||
gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment);
|
||||
|
||||
GemmPreferenceKey preference_key(cc, alignment);
|
||||
|
||||
blockwise_gemm_operations[functional_key][preference_key].push_back(op);
|
||||
}
|
||||
|
||||
// insert all gemm operation into operation table
|
||||
if (desc.kind == OperationKind::kGemm) {
|
||||
@ -157,29 +192,57 @@ void OperationTable::append(Manifest const &manifest) {
|
||||
}
|
||||
else {
|
||||
const BlockScaleDescription &block_scale_desc = grouped_gemm_desc.block_scales.value();
|
||||
BlockScaledGemmFunctionalKey functional_key(
|
||||
gemm_desc.provider,
|
||||
gemm_desc.gemm_kind,
|
||||
gemm_desc.kind,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
gemm_desc.A.element,
|
||||
gemm_desc.A.layout,
|
||||
block_scale_desc.SFA.element,
|
||||
gemm_desc.B.element,
|
||||
gemm_desc.B.layout,
|
||||
block_scale_desc.SFB.element,
|
||||
gemm_desc.C.element,
|
||||
gemm_desc.C.layout,
|
||||
gemm_desc.D.element,
|
||||
gemm_desc.D.layout,
|
||||
block_scale_desc.SFD.element,
|
||||
block_scale_desc.SFD.layout,
|
||||
block_scale_desc.SFVecSize,
|
||||
block_scale_desc.EpilogueSFVecSize
|
||||
);
|
||||
if (block_scale_desc.kind == OperationKind::kBlockScaledGemm) {
|
||||
|
||||
BlockScaledGemmFunctionalKey functional_key(
|
||||
gemm_desc.provider,
|
||||
gemm_desc.gemm_kind,
|
||||
gemm_desc.kind,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
gemm_desc.A.element,
|
||||
gemm_desc.A.layout,
|
||||
block_scale_desc.SFA.element,
|
||||
gemm_desc.B.element,
|
||||
gemm_desc.B.layout,
|
||||
block_scale_desc.SFB.element,
|
||||
gemm_desc.C.element,
|
||||
gemm_desc.C.layout,
|
||||
gemm_desc.D.element,
|
||||
gemm_desc.D.layout,
|
||||
block_scale_desc.SFD.element,
|
||||
block_scale_desc.SFD.layout,
|
||||
block_scale_desc.SFKVecSize,
|
||||
block_scale_desc.EpilogueSFVecSize
|
||||
);
|
||||
|
||||
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
|
||||
block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
|
||||
}
|
||||
else {
|
||||
assert(block_scale_desc.kind == OperationKind::kBlockwiseGemm);
|
||||
BlockwiseGemmFunctionalKey functional_key(
|
||||
gemm_desc.provider,
|
||||
gemm_desc.gemm_kind,
|
||||
gemm_desc.kind,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
gemm_desc.A.element,
|
||||
gemm_desc.A.layout,
|
||||
block_scale_desc.SFA.element,
|
||||
gemm_desc.B.element,
|
||||
gemm_desc.B.layout,
|
||||
block_scale_desc.SFB.element,
|
||||
gemm_desc.C.element,
|
||||
gemm_desc.C.layout,
|
||||
gemm_desc.D.element,
|
||||
gemm_desc.D.layout,
|
||||
block_scale_desc.SFMVecSize,
|
||||
block_scale_desc.SFNVecSize,
|
||||
block_scale_desc.SFKVecSize
|
||||
);
|
||||
|
||||
blockwise_gemm_operations[functional_key][preference_key].push_back(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
58
tools/library/src/reference/blockwise_gemm_fp8_bf16out.cu
Normal file
58
tools/library/src/reference/blockwise_gemm_fp8_bf16out.cu
Normal file
@ -0,0 +1,58 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "blockwise_gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Registers host reference operations for blockwise GEMM with bfloat16_t
/// output D: once with C = void (template-selected no-C variant) and once
/// with C = bfloat16_t.
void initialize_blockwise_gemm_reference_operations_bf16out(Manifest &manifest) {
  initialize_blockwise_gemm_reference_operations_given_C_and_D<void, bfloat16_t>(manifest);
  initialize_blockwise_gemm_reference_operations_given_C_and_D<bfloat16_t, bfloat16_t>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
58
tools/library/src/reference/blockwise_gemm_fp8_fp16out.cu
Normal file
58
tools/library/src/reference/blockwise_gemm_fp8_fp16out.cu
Normal file
@ -0,0 +1,58 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "blockwise_gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Registers host reference operations for blockwise GEMM with half_t
/// output D: once with C = void (template-selected no-C variant) and once
/// with C = half_t.
void initialize_blockwise_gemm_reference_operations_fp16out(Manifest &manifest) {
  initialize_blockwise_gemm_reference_operations_given_C_and_D<void, half_t>(manifest);
  initialize_blockwise_gemm_reference_operations_given_C_and_D<half_t, half_t>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
58
tools/library/src/reference/blockwise_gemm_fp8_fp32out.cu
Normal file
58
tools/library/src/reference/blockwise_gemm_fp8_fp32out.cu
Normal file
@ -0,0 +1,58 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "blockwise_gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Registers host reference operations for blockwise GEMM with float
/// output D: once with C = void (template-selected no-C variant) and once
/// with C = float.
void initialize_blockwise_gemm_reference_operations_fp32out(Manifest &manifest) {
  initialize_blockwise_gemm_reference_operations_given_C_and_D<void, float>(manifest);
  initialize_blockwise_gemm_reference_operations_given_C_and_D<float, float>(manifest);
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
664
tools/library/src/reference/blockwise_gemm_reference_operation.h
Normal file
664
tools/library/src/reference/blockwise_gemm_reference_operation.h
Normal file
@ -0,0 +1,664 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "library_internal.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/detail/blockwise_scale_layout.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
Provider Provider_,
|
||||
typename ElementA_,
|
||||
typename LayoutA_,
|
||||
typename LayoutSFA_,
|
||||
typename ElementSFA_,
|
||||
typename ElementB_,
|
||||
typename LayoutB_,
|
||||
typename LayoutSFB_,
|
||||
typename ElementSFB_,
|
||||
typename ElementC_,
|
||||
typename LayoutC_,
|
||||
typename ElementCompute_,
|
||||
typename ElementAccumulator_ = ElementCompute_,
|
||||
typename ElementD_ = ElementC_,
|
||||
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
|
||||
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
||||
>
|
||||
class BlockwiseGemmReferenceOperation : public Operation {
|
||||
public:
|
||||
static Provider const kProvider = Provider_;
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = LayoutA_;
|
||||
using ElementSFA = ElementSFA_;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = LayoutB_;
|
||||
using ElementSFB = ElementSFB_;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ConvertOp = ConvertOp_;
|
||||
using InnerProductOp = InnerProductOp_;
|
||||
|
||||
protected:
|
||||
|
||||
/// Storage for the name string
|
||||
std::string name_;
|
||||
|
||||
///
|
||||
BlockwiseGemmDescription description_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_)
|
||||
: SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) {
|
||||
|
||||
// Basic information
|
||||
description_.provider = kProvider;
|
||||
description_.kind = OperationKind::kBlockwiseGemm;
|
||||
description_.gemm_kind = GemmKind::kUniversal;
|
||||
|
||||
// Tensor description
|
||||
description_.A = make_TensorDescription<ElementA, LayoutA>();
|
||||
description_.SFA = make_TensorDescription<ElementSFA, LayoutSFA_>();
|
||||
description_.B = make_TensorDescription<ElementB, LayoutB>();
|
||||
description_.SFB = make_TensorDescription<ElementSFB, LayoutSFB_>();
|
||||
description_.C = make_TensorDescription<ElementC, LayoutC>();
|
||||
description_.D = make_TensorDescription<ElementD, LayoutC>();
|
||||
|
||||
// Epilogue compute and accumulator type description
|
||||
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.element_accumulator =
|
||||
NumericTypeMap<ElementAccumulator>::kId;
|
||||
|
||||
// Compute capability for gemm reference
|
||||
description_.tile_description.minimum_compute_capability =
|
||||
(kProvider == Provider::kReferenceDevice ? 50 : 0);
|
||||
|
||||
description_.tile_description.maximum_compute_capability = 1024;
|
||||
|
||||
description_.SFMVecSize = SFMVecSize;
|
||||
description_.SFNVecSize = SFNVecSize;
|
||||
description_.SFKVecSize = SFKVecSize;
|
||||
|
||||
// Procedural name
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "gemm"
|
||||
<< "_reference_" << to_string(description_.provider)
|
||||
<< "_" << to_string(description_.A.element) << to_string(description_.A.layout)
|
||||
<< "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout)
|
||||
<< "_" << to_string(description_.B.element) << to_string(description_.B.layout)
|
||||
<< "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout)
|
||||
<< "_" << to_string(description_.C.element) << to_string(description_.C.layout)
|
||||
<< "_" << to_string(description_.tile_description.math_instruction.element_accumulator);
|
||||
|
||||
name_ = ss.str();
|
||||
|
||||
description_.name = name_.c_str();
|
||||
|
||||
// Epilogue compute and accumulator type description
|
||||
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
||||
|
||||
description_.tile_description.math_instruction.element_accumulator =
|
||||
NumericTypeMap<ElementAccumulator>::kId;
|
||||
}
|
||||
|
||||
/// Returns the description of the GEMM operation
|
||||
virtual OperationDescription const & description() const {
|
||||
return description_;
|
||||
}
|
||||
|
||||
virtual Status can_implement(
|
||||
void const *configuration,
|
||||
void const *arguments) const {
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
virtual uint64_t get_host_workspace_size(
|
||||
void const *configuration) const {
|
||||
|
||||
return sizeof(GemmUniversalConfiguration);
|
||||
}
|
||||
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration,
|
||||
void const *arguments = nullptr) const {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
virtual Status initialize(
|
||||
void const *configuration,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
virtual Status run(
|
||||
void const *arguments,
|
||||
void *host_workspace,
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
using namespace cute;
|
||||
|
||||
BlockwiseGemmArguments const &args = *static_cast<BlockwiseGemmArguments const *>(arguments);
|
||||
|
||||
// Construct cute::Tensor A/B/C
|
||||
|
||||
int M = args.problem_size.m();
|
||||
int N = args.problem_size.n();
|
||||
int K = args.problem_size.k();
|
||||
int L = args.batch_count;
|
||||
|
||||
auto problem_shape_MNKL = cute::make_shape(M, N, K, L);
|
||||
|
||||
auto alpha = *(static_cast<ElementCompute const*>(args.alpha));
|
||||
auto beta = *(static_cast<ElementCompute const*>(args.beta));
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
|
||||
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
|
||||
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutC>;
|
||||
|
||||
auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>;
|
||||
auto A = cute::make_tensor(static_cast<ElementA const*>(args.A),
|
||||
cute::make_layout(cute::make_shape(M, K, L), stride_a));
|
||||
auto SfA = make_tensor(static_cast<ElementSFA const*>(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize)));
|
||||
|
||||
auto B = cute::make_tensor(static_cast<ElementB const*>(args.B),
|
||||
cute::make_layout(cute::make_shape(N, K, L), stride_b));
|
||||
auto SfB = make_tensor(static_cast<ElementSFB const*>(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize)));
|
||||
|
||||
auto C = [&]() {
|
||||
if constexpr (not is_same_v<ElementC, void>) {
|
||||
return cute::make_tensor(static_cast<ElementC const*>(args.C),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_c));
|
||||
}
|
||||
else {
|
||||
return cute::make_tensor(static_cast<ElementD const*>(nullptr),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_c));
|
||||
}
|
||||
}();
|
||||
|
||||
auto D = cute::make_tensor(static_cast<ElementD *>(args.D),
|
||||
cute::make_layout(cute::make_shape(M, N, L), stride_d));
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
|
||||
decltype(A), decltype(SfA),
|
||||
decltype(B), decltype(SfB)>
|
||||
mainloop_params{A, SfA, B, SfB};
|
||||
|
||||
// W/O SF generation
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute,
|
||||
decltype(C), decltype(D)>
|
||||
epilogue_params{alpha, beta, C, D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
private:
|
||||
int SFMVecSize;
|
||||
int SFNVecSize;
|
||||
int SFKVecSize;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename ElementA_,
|
||||
typename ElementSFA_,
|
||||
typename ElementB_,
|
||||
typename ElementSFB_,
|
||||
typename ElementC_,
|
||||
typename ElementCompute_,
|
||||
typename ElementAccumulator_ = ElementCompute_,
|
||||
typename ElementD_ = ElementC_,
|
||||
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
|
||||
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
||||
>
|
||||
void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) {
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
manifest.append(new BlockwiseGemmReferenceOperation<
|
||||
Provider::kReferenceHost,
|
||||
ElementA_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFA_,
|
||||
ElementB_,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementSFB_,
|
||||
ElementC_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementCompute_,
|
||||
ElementAccumulator_,
|
||||
ElementD_,
|
||||
ConvertOp_,
|
||||
InnerProductOp_
|
||||
>(SFMVecSize, SFNVecSize, SFKVecSize));
|
||||
|
||||
|
||||
}
|
||||
|
||||
template<class ElementC,
|
||||
class ElementD>
|
||||
void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) {
|
||||
make_blockwise_gemm<
|
||||
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 1 , 128);
|
||||
make_blockwise_gemm<
|
||||
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 128, 128);
|
||||
make_blockwise_gemm<
|
||||
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 128, 128, 128);
|
||||
|
||||
|
||||
make_blockwise_gemm<
|
||||
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 1 , 128);
|
||||
make_blockwise_gemm<
|
||||
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 128, 128);
|
||||
make_blockwise_gemm<
|
||||
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 128, 128, 128);
|
||||
|
||||
make_blockwise_gemm<
|
||||
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 1 , 128);
|
||||
make_blockwise_gemm<
|
||||
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 128, 128);
|
||||
make_blockwise_gemm<
|
||||
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 128, 128, 128);
|
||||
|
||||
|
||||
make_blockwise_gemm<
|
||||
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 1 , 128);
|
||||
make_blockwise_gemm<
|
||||
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 1, 128, 128);
|
||||
make_blockwise_gemm<
|
||||
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
|
||||
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
|
||||
>(manifest, 128, 128, 128);
|
||||
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -64,6 +64,10 @@ void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest);
|
||||
void initialize_blockwise_gemm_reference_operations_fp32out(Manifest &manifest);
|
||||
void initialize_blockwise_gemm_reference_operations_fp16out(Manifest &manifest);
|
||||
void initialize_blockwise_gemm_reference_operations_bf16out(Manifest &manifest);
|
||||
|
||||
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
|
||||
@ -113,6 +117,9 @@ void initialize_reference_operations(Manifest &manifest) {
|
||||
initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest);
|
||||
initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest);
|
||||
initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest);
|
||||
initialize_blockwise_gemm_reference_operations_fp32out(manifest);
|
||||
initialize_blockwise_gemm_reference_operations_fp16out(manifest);
|
||||
initialize_blockwise_gemm_reference_operations_bf16out(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -334,6 +334,7 @@ static struct {
|
||||
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
|
||||
{"gemm", "Gemm", OperationKind::kGemm},
|
||||
{"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm},
|
||||
{"blockwise_gemm", "blockwiseGemm", OperationKind::kBlockwiseGemm},
|
||||
{"rank_k", "RankK", OperationKind::kRankK},
|
||||
{"rank_2k", "Rank2K", OperationKind::kRank2K},
|
||||
{"trmm", "Trmm", OperationKind::kTrmm},
|
||||
|
||||
@ -48,6 +48,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
|
||||
src/gemm_operation_profiler.cu
|
||||
src/grouped_gemm_operation_profiler.cu
|
||||
src/block_scaled_gemm_operation_profiler.cu
|
||||
src/blockwise_gemm_operation_profiler.cu
|
||||
src/rank_k_operation_profiler.cu
|
||||
src/rank_2k_operation_profiler.cu
|
||||
src/trmm_operation_profiler.cu
|
||||
|
||||
@ -0,0 +1,290 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
  \brief Blockwise GEMM Profiler
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
// Profiler includes
|
||||
#include "options.h"
|
||||
#include "device_context.h"
|
||||
#include "operation_profiler.h"
|
||||
#include "performance_result.h"
|
||||
#include "problem_space.h"
|
||||
#include "reduction_operation_profiler.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Abstract base class for each math function
|
||||
class BlockwiseGemmOperationProfiler : public OperationProfiler {
|
||||
public:
|
||||
|
||||
/// Problem structure obtained from problem space
|
||||
  struct GemmProblem {

    /// GEMM execution mode (universal kGemm by default)
    cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm};

    // Problem extents
    int64_t m{16};
    int64_t n{16};
    int64_t k{16};

    // Blockwise scale-factor vector sizes along M, N, and K
    // (0 here; presumably filled in by parse() from the operation
    //  description or problem space — confirm)
    int64_t sf_vec_m{0};
    int64_t sf_vec_n{0};
    int64_t sf_vec_k{0};

    // Threadblock cluster shape and its fallback counterpart
    int cluster_m{1};
    int cluster_n{1};
    int cluster_k{1};
    int cluster_m_fallback{1};
    int cluster_n_fallback{1};
    int cluster_k_fallback{1};


    // Leading dimensions of A, B, and C (0 until initialized)
    int64_t lda{0};
    int64_t ldb{0};
    int64_t ldc{0};

    // Epilogue scalars stored as raw bytes; the element type depends on the
    // operation's epilogue compute type
    std::vector<uint8_t> alpha;
    std::vector<uint8_t> beta;

    // Split-K configuration
    cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone};
    int split_k_slices{1};
    int batch_count{1};

    // Threadblock rasterization controls
    cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
    int swizzle_size{1};


    // Runtime-selected input datatypes for the A and B operands
    cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
    cutlass::library::RuntimeDatatype runtime_input_datatype_b{};


    // gemm with parallel interleaved reduction
    // gemm epilogue (alpha, beta) = (1.0, 0.0)
    // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
    std::vector<uint8_t> alpha_one;
    std::vector<uint8_t> beta_zero;

    // Whether to use PDL (programmatic dependent launch) — TODO confirm
    bool use_pdl{false};
    //
    // Methods
    //

    /// Parses the problem
    Status parse(
      library::BlockwiseGemmDescription const &operation_desc,
      ProblemSpace const &problem_space,
      ProblemSpace::Problem const &problem);

    /// Total number of bytes loaded
    int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const;

    /// Total number of flops computed
    int64_t flops(library::BlockwiseGemmDescription const &operation_desc) const;

    /// Initializes a performance result
    void initialize_result(
      PerformanceResult &result,
      library::BlockwiseGemmDescription const &operation_desc,
      ProblemSpace const &problem_space);
  };
|
||||
|
||||
/// Workspace used
|
||||
  struct GemmWorkspace {

    // Device allocations for the operands and scale factors.
    // NOTE(review): raw non-owning pointers — presumably owned by the
    // DeviceContext; confirm lifetime against the profiler's workspace setup.
    DeviceAllocation *A{nullptr};
    DeviceAllocation *SFA{nullptr};
    DeviceAllocation *B{nullptr};
    DeviceAllocation *SFB{nullptr};
    DeviceAllocation *C{nullptr};
    // Output written by the operation under test
    DeviceAllocation *Computed{nullptr};
    // Output written by the reference implementation for verification
    DeviceAllocation *Reference{nullptr};

    /// Number of copies of the problem workspace which are visited sequentially during
    /// profiling to avoid camping in the last level cache.
    int problem_count{1};

    // Library-level configuration and arguments for the GEMM operation
    library::GemmUniversalConfiguration configuration;
    library::BlockwiseGemmArguments arguments;

    /// Buffer used for the operation's host workspace
    std::vector<uint8_t> host_workspace;

    /// Buffer used for the operations' device workspace
    DeviceAllocation device_workspace;

    /// Library configuration and arguments for reduction operator
    library::ReductionConfiguration reduction_configuration;
    library::ReductionArguments reduction_arguments;

    /// Buffer used for the cutlass reduction operations' host workspace
    std::vector<uint8_t> reduction_host_workspace;
  };
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// GEMM problem obtained from problem space
|
||||
GemmProblem problem_;
|
||||
|
||||
/// Device memory allocations
|
||||
GemmWorkspace gemm_workspace_;
|
||||
|
||||
/// CUTLASS parallel reduction operation to follow this* gemm operation
|
||||
library::Operation const *reduction_op_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
BlockwiseGemmOperationProfiler(Options const &options);
|
||||
|
||||
/// Destructor
|
||||
virtual ~BlockwiseGemmOperationProfiler();
|
||||
|
||||
GemmProblem const& problem() const { return problem_; }
|
||||
|
||||
/// Prints usage statement for the math function
|
||||
virtual void print_usage(std::ostream &out) const;
|
||||
|
||||
/// Prints examples
|
||||
virtual void print_examples(std::ostream &out) const;
|
||||
|
||||
/// Extracts the problem dimensions
|
||||
virtual Status initialize_configuration(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Initializes workspace
|
||||
virtual Status initialize_workspace(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
virtual bool verify_cutlass(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Measures performance results
|
||||
virtual bool profile(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
protected:
|
||||
|
||||
/// Initializes the performance result
|
||||
void initialize_result_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
library::BlockwiseGemmDescription const &operation_desc,
|
||||
ProblemSpace const &problem_space);
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
bool verify_with_cublas_(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
/// Verifies CUTLASS against host and device references
|
||||
bool verify_with_reference_(
|
||||
Options const &options,
|
||||
PerformanceReport &report,
|
||||
DeviceContext &device_context,
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem,
|
||||
cutlass::library::NumericTypeID element_A,
|
||||
cutlass::library::NumericTypeID element_B);
|
||||
|
||||
/// Method to profile a CUTLASS Operation
|
||||
Status profile_cutlass_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
library::Operation const *operation,
|
||||
void *arguments,
|
||||
void *host_workspace,
|
||||
void *device_workspace);
|
||||
|
||||
/// Initialize reduction problem dimensions and library::Operation
|
||||
bool initialize_reduction_configuration_(
|
||||
library::Operation const *operation,
|
||||
ProblemSpace::Problem const &problem);
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -198,6 +198,11 @@ private:
|
||||
arguments.SFD = block_scaled_ws.SFD_ptr_array_device[0]->data();
|
||||
arguments.norm_constant = block_scaled_ws.norm_constant->data();
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto& block_scaled_ws = gemm_workspace_.block_scales.value();
|
||||
arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data();
|
||||
arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -208,6 +213,7 @@ protected:
|
||||
GroupedGemmWorkspace gemm_workspace_;
|
||||
|
||||
bool is_block_scaled{false};
|
||||
bool is_blockwise{false};
|
||||
|
||||
public:
|
||||
GroupedGemmOperationProfiler(Options const& options);
|
||||
|
||||
@ -437,9 +437,11 @@ void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "k", problem_space, k);
|
||||
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, cluster_m);
|
||||
set_argument(result, "cluster_n", problem_space, cluster_n);
|
||||
set_argument(result, "cluster_k", problem_space, cluster_k);
|
||||
auto cluster_shape = operation_desc.tile_description.cluster_shape;
|
||||
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
|
||||
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
|
||||
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
|
||||
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
|
||||
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
|
||||
1299
tools/profiler/src/blockwise_gemm_operation_profiler.cu
Normal file
1299
tools/profiler/src/blockwise_gemm_operation_profiler.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -37,6 +37,7 @@
|
||||
|
||||
// Profiler includes
|
||||
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/blockwise_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/conv2d_operation_profiler.h"
|
||||
#include "cutlass/profiler/conv3d_operation_profiler.h"
|
||||
#include "cutlass/profiler/cutlass_profiler.h"
|
||||
@ -64,6 +65,8 @@ CutlassProfiler::CutlassProfiler(
|
||||
|
||||
operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new BlockwiseGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new Conv2dOperationProfiler(options));
|
||||
|
||||
@ -440,10 +440,11 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "n", problem_space, n);
|
||||
set_argument(result, "k", problem_space, k);
|
||||
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, cluster_m);
|
||||
set_argument(result, "cluster_n", problem_space, cluster_n);
|
||||
set_argument(result, "cluster_k", problem_space, cluster_k);
|
||||
auto cluster_shape = operation_desc.tile_description.cluster_shape;
|
||||
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
|
||||
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
|
||||
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
|
||||
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
|
||||
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
|
||||
@ -40,6 +40,7 @@
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
@ -459,9 +460,11 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result(
|
||||
set_argument(result, "problem-sizes", problem_space, ss.str());
|
||||
}
|
||||
|
||||
set_argument(result, "cluster_m", problem_space, cluster_m);
|
||||
set_argument(result, "cluster_n", problem_space, cluster_n);
|
||||
set_argument(result, "cluster_k", problem_space, cluster_k);
|
||||
auto cluster_shape = operation_desc.gemm.tile_description.cluster_shape;
|
||||
auto is_dynamic = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0;
|
||||
set_argument(result, "cluster_m", problem_space, is_dynamic ? this->cluster_m : cluster_shape.m());
|
||||
set_argument(result, "cluster_n", problem_space, is_dynamic ? this->cluster_n : cluster_shape.n());
|
||||
set_argument(result, "cluster_k", problem_space, is_dynamic ? this->cluster_k : cluster_shape.k());
|
||||
set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback);
|
||||
set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback);
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
@ -497,10 +500,22 @@ Status GroupedGemmOperationProfiler::initialize_configuration(
|
||||
// We distinguish between block scaled and non-block scaled operations by looking at the kernel
|
||||
// name, which tells us what reference kernel to use, which arguments to pass to the operation
|
||||
// etc. This avoids creating yet another OperationProfiler with a lot of boilerplate in it.
|
||||
|
||||
std::string sf_tuple = "\\d+x\\d+";
|
||||
std::string datatypes_regex = "\\w?f\\d+|e\\dm\\d"; // bf16 | f16 | f32 | e4m3 | ...
|
||||
std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" +
|
||||
datatypes_regex + ")_" + sf_tuple + "(" +
|
||||
datatypes_regex + ")x(" + datatypes_regex + ")";
|
||||
|
||||
|
||||
if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) {
|
||||
is_block_scaled = true;
|
||||
gemm_workspace_.block_scales = BlockScalingWorkspace{};
|
||||
}
|
||||
else if (std::regex_search(operation_desc.gemm.name, std::regex(blockwise_regex_string))) {
|
||||
is_blockwise = true;
|
||||
gemm_workspace_.block_scales = BlockScalingWorkspace{};
|
||||
}
|
||||
else {
|
||||
is_block_scaled = false;
|
||||
gemm_workspace_.block_scales = std::nullopt;
|
||||
@ -605,6 +620,12 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
block_scaling_ws.SFD_ptr_array_host.resize(num_groups);
|
||||
block_scaling_ws.SFD_reference_ptr_array_host.resize(num_groups);
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto& block_scaling_ws = gemm_workspace_.block_scales.value();
|
||||
block_scaling_ws.SFA_ptr_array_host.resize(num_groups);
|
||||
block_scaling_ws.SFB_ptr_array_host.resize(num_groups);
|
||||
block_scaling_ws.SFC_ptr_array_host.resize(num_groups);
|
||||
}
|
||||
static_assert(sizeof(void*) == 8); // allocating blocks for pointers, so verify pointer size
|
||||
// ldx
|
||||
gemm_workspace_.lda_array_device =
|
||||
@ -698,7 +719,7 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
int sfa_m = round_up(int(problem_.m(group_idx)), 128);
|
||||
int sfb_n = round_up(int(problem_.n(group_idx)), 128);
|
||||
int sfa_sfb_k =
|
||||
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFVecSize), 4);
|
||||
round_up(ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize), 4);
|
||||
|
||||
int sfd_m =
|
||||
block_scale_desc.SFD.layout == cutlass::library::LayoutTypeID::kRowMajor
|
||||
@ -760,6 +781,37 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
block_scale_ws.SFD_ptr_array_host[group_idx]->fill_device(0);
|
||||
}
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto const block_scale_desc = operation_desc.block_scales.value();
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
int sfa_m = ceil_div(int(problem_.m(group_idx)), block_scale_desc.SFMVecSize);
|
||||
int sfb_n = ceil_div(int(problem_.n(group_idx)), block_scale_desc.SFNVecSize);
|
||||
int sfa_sfb_k = ceil_div(int(problem_.k(group_idx)), block_scale_desc.SFKVecSize);
|
||||
|
||||
block_scale_ws.SFA_ptr_array_host[group_idx] =
|
||||
device_context.allocate_and_initialize_tensor(
|
||||
options,
|
||||
"SFA_" + std::to_string(group_idx),
|
||||
block_scale_desc.SFA.element,
|
||||
block_scale_desc.SFA.layout,
|
||||
{sfa_m, sfa_sfb_k},
|
||||
{sfa_m},
|
||||
gemm_workspace_.problem_count,
|
||||
seed_shift++,
|
||||
0);
|
||||
|
||||
block_scale_ws.SFB_ptr_array_host[group_idx] =
|
||||
device_context.allocate_and_initialize_tensor(
|
||||
options,
|
||||
"SFB_" + std::to_string(group_idx),
|
||||
block_scale_desc.SFB.element,
|
||||
block_scale_desc.SFB.layout,
|
||||
{sfa_sfb_k, sfb_n},
|
||||
{sfb_n},
|
||||
gemm_workspace_.problem_count,
|
||||
seed_shift++,
|
||||
0);
|
||||
}
|
||||
}
|
||||
|
||||
// takes the allocated tensors and initializes an array of pointers per problem in the workspace
|
||||
@ -825,6 +877,18 @@ Status GroupedGemmOperationProfiler::initialize_workspace(
|
||||
0 // device_index
|
||||
);
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
create_dev_ptr_array_all_workspace(
|
||||
block_scale_ws.SFA_ptr_array_device,
|
||||
block_scale_ws.SFA_ptr_array_host,
|
||||
"SFA");
|
||||
create_dev_ptr_array_all_workspace(
|
||||
block_scale_ws.SFB_ptr_array_device,
|
||||
block_scale_ws.SFB_ptr_array_host,
|
||||
"SFB");
|
||||
}
|
||||
|
||||
init_arguments(options);
|
||||
}
|
||||
|
||||
@ -896,6 +960,11 @@ bool GroupedGemmOperationProfiler::verify_cutlass(
|
||||
init_arguments(options);
|
||||
|
||||
library::Operation const* underlying_operation = operation;
|
||||
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
|
||||
if (results_.back().status != Status::kSuccess) {
|
||||
return false;
|
||||
}
|
||||
|
||||
results_.back().status = underlying_operation->run(
|
||||
&gemm_workspace_.arguments,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
@ -998,7 +1067,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
}
|
||||
|
||||
// we only have a block scaled reference kernel implemented on the host
|
||||
if (is_block_scaled && provider != library::Provider::kReferenceHost) {
|
||||
if ((is_block_scaled || is_blockwise) && provider != library::Provider::kReferenceHost) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1064,12 +1133,22 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
ptr_norm_constant = host_data_norm_constant.data();
|
||||
ws.norm_constant->copy_to_host(ptr_norm_constant);
|
||||
}
|
||||
else if (is_blockwise) {
|
||||
auto const& ws = gemm_workspace_.block_scales.value();
|
||||
|
||||
host_data_SFA.resize(ws.SFA_ptr_array_host[group_idx]->bytes());
|
||||
ptr_SFA = host_data_SFA.data();
|
||||
ws.SFA_ptr_array_host[group_idx]->copy_to_host(ptr_SFA);
|
||||
host_data_SFB.resize(ws.SFB_ptr_array_host[group_idx]->bytes());
|
||||
ptr_SFB = host_data_SFB.data();
|
||||
ws.SFB_ptr_array_host[group_idx]->copy_to_host(ptr_SFB);
|
||||
}
|
||||
}
|
||||
|
||||
const auto &desc = static_cast<library::GroupedGemmDescription const &>(operation->description());
|
||||
const auto& gemm_desc = desc.gemm;
|
||||
|
||||
if (!is_block_scaled) {
|
||||
if (!is_block_scaled and !is_blockwise) {
|
||||
library::Handle handle;
|
||||
handle.set_provider(provider);
|
||||
|
||||
@ -1112,7 +1191,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride());
|
||||
}
|
||||
else {
|
||||
else if (is_block_scaled) {
|
||||
auto const& block_scale_desc = desc.block_scales.value();
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
|
||||
@ -1134,7 +1213,7 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
gemm_desc.D.layout,
|
||||
block_scale_desc.SFD.element,
|
||||
block_scale_desc.SFD.layout,
|
||||
block_scale_desc.SFVecSize,
|
||||
block_scale_desc.SFKVecSize,
|
||||
block_scale_desc.EpilogueSFVecSize);
|
||||
|
||||
auto operators_it =
|
||||
@ -1208,6 +1287,100 @@ bool GroupedGemmOperationProfiler::verify_with_reference_(
|
||||
|
||||
block_scale_ws.SFD_reference_ptr_array_host[group_idx]->copy_from_host(ptr_SFD);
|
||||
}
|
||||
else {
|
||||
// Blockwise
|
||||
auto const& block_scale_desc = desc.block_scales.value();
|
||||
auto& block_scale_ws = gemm_workspace_.block_scales.value();
|
||||
|
||||
library::BlockwiseGemmFunctionalKey blockwiseGemm_key(
|
||||
library::Provider::kReferenceHost,
|
||||
library::GemmKind::kUniversal,
|
||||
library::OperationKind::kBlockwiseGemm,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
element_A,
|
||||
gemm_desc.A.layout,
|
||||
block_scale_desc.SFA.element,
|
||||
element_B,
|
||||
gemm_desc.B.layout,
|
||||
block_scale_desc.SFB.element,
|
||||
gemm_desc.C.element,
|
||||
gemm_desc.C.layout,
|
||||
gemm_desc.D.element,
|
||||
gemm_desc.D.layout,
|
||||
block_scale_desc.SFMVecSize,
|
||||
block_scale_desc.SFNVecSize,
|
||||
block_scale_desc.SFKVecSize
|
||||
);
|
||||
|
||||
auto operators_it = library::Singleton::get().operation_table.blockwise_gemm_operations.find(blockwiseGemm_key);
|
||||
if (
|
||||
operators_it ==
|
||||
library::Singleton::get().operation_table.blockwise_gemm_operations.end()) {
|
||||
disposition = Disposition::kNotSupported;
|
||||
break;
|
||||
}
|
||||
|
||||
if (operators_it->second.empty()) {
|
||||
disposition = Disposition::kNotSupported;
|
||||
break;
|
||||
}
|
||||
|
||||
auto cc_it = operators_it->second.begin();
|
||||
if (cc_it == operators_it->second.end()) {
|
||||
disposition = Disposition::kNotSupported;
|
||||
break;
|
||||
}
|
||||
|
||||
// host reference has only one instances in BlockScaledOperationVectorMap
|
||||
library::Operation const* reference_op = cc_it->second[0];
|
||||
|
||||
library::BlockwiseGemmArguments arguments {
|
||||
{int(problem_.m(group_idx)), int(problem_.n(group_idx)), int(problem_.k(group_idx))},
|
||||
{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)},
|
||||
{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)},
|
||||
1, // batch_count
|
||||
ptr_A,
|
||||
ptr_B,
|
||||
ptr_SFA,
|
||||
ptr_SFB,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
problem_.alpha.data(),
|
||||
problem_.beta.data(),
|
||||
library::ScalarPointerMode::kHost,
|
||||
problem_.lda[group_idx],
|
||||
problem_.ldb[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
gemm_workspace_.A_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.B_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(),
|
||||
gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride(),
|
||||
};
|
||||
|
||||
library::GemmUniversalConfiguration configuration{
|
||||
library::GemmUniversalMode::kGemm,
|
||||
problem_.problem_sizes[group_idx],
|
||||
{problem_.cluster_m, problem_.cluster_n, problem_.cluster_k},
|
||||
{problem_.cluster_m_fallback, problem_.cluster_n_fallback, problem_.cluster_k_fallback},
|
||||
1,
|
||||
problem_.lda[group_idx],
|
||||
problem_.ldb[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
problem_.ldc[group_idx],
|
||||
1,
|
||||
};
|
||||
uint64_t host_workspace_size_needed = reference_op->get_host_workspace_size(&gemm_workspace_.configuration);
|
||||
std::vector<char> host_workspace(host_workspace_size_needed);
|
||||
status = reference_op->initialize(&configuration, host_workspace.data());
|
||||
if (status != Status::kSuccess) {
|
||||
break;
|
||||
}
|
||||
|
||||
status = reference_op->run(&arguments, host_workspace.data());
|
||||
}
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
break;
|
||||
}
|
||||
@ -1292,6 +1465,10 @@ Status GroupedGemmOperationProfiler::profile_cutlass_(
|
||||
void* device_workspace) {
|
||||
|
||||
library::Operation const* underlying_operation = operation;
|
||||
results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments);
|
||||
if (results_.back().status != Status::kSuccess) {
|
||||
return results_.back().status;
|
||||
}
|
||||
|
||||
auto func = [&](cudaStream_t stream, int iteration) {
|
||||
// Iterate over copies of the problem in memory
|
||||
|
||||
@ -301,6 +301,9 @@ std::ostream& operator<<(std::ostream& out, library::OperationKind op_kind) {
|
||||
else if (op_kind == library::OperationKind::kBlockScaledGemm) {
|
||||
out << "kBlockScaledGemm";
|
||||
}
|
||||
else if (op_kind == library::OperationKind::kBlockwiseGemm) {
|
||||
out << "kBlockwiseGemm";
|
||||
}
|
||||
else if (op_kind == library::OperationKind::kRankK) {
|
||||
out << "kRankK";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user