Files
cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h
2025-07-21 22:03:55 -04:00

808 lines
27 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Defines reference operations for the blockwise/groupwise GEMM operation kinds in the
           CUTLASS Library.
*/
#pragma once
#include <iostream>
#include <sstream>
#include <cstring>
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/util.h"
#include "cutlass/util/packed_stride.hpp"
#include "library_internal.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/detail/blockwise_scale_layout.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
Provider Provider_,
typename ElementA_,
typename LayoutA_,
typename LayoutSFA_,
typename ElementSFA_,
typename ElementB_,
typename LayoutB_,
typename LayoutSFB_,
typename ElementSFB_,
typename ElementC_,
typename LayoutC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ElementD_ = ElementC_,
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
class BlockwiseGemmReferenceOperation : public Operation {
public:
static Provider const kProvider = Provider_;
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using ElementSFA = ElementSFA_;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using ElementSFB = ElementSFB_;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using ElementD = ElementD_;
using ElementCompute = ElementCompute_;
using ElementAccumulator = ElementAccumulator_;
using ConvertOp = ConvertOp_;
using InnerProductOp = InnerProductOp_;
protected:
/// Storage for the name string
std::string name_;
///
BlockwiseGemmDescription description_;
public:
/// Constructor
BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_)
: SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) {
// Basic information
description_.provider = kProvider;
description_.kind = OperationKind::kBlockwiseGemm;
description_.gemm_kind = GemmKind::kUniversal;
// Tensor description
description_.A = make_TensorDescription<ElementA, LayoutA>();
description_.SFA = make_TensorDescription<ElementSFA, LayoutSFA_>();
description_.B = make_TensorDescription<ElementB, LayoutB>();
description_.SFB = make_TensorDescription<ElementSFB, LayoutSFB_>();
description_.C = make_TensorDescription<ElementC, LayoutC>();
description_.D = make_TensorDescription<ElementD, LayoutC>();
// Epilogue compute and accumulator type description
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
// Compute capability for gemm reference
description_.tile_description.minimum_compute_capability =
(kProvider == Provider::kReferenceDevice ? 50 : 0);
description_.tile_description.maximum_compute_capability = 1024;
description_.SFMVecSize = SFMVecSize;
description_.SFNVecSize = SFNVecSize;
description_.SFKVecSize = SFKVecSize;
// Procedural name
std::stringstream ss;
ss << "gemm"
<< "_reference_" << to_string(description_.provider)
<< "_" << to_string(description_.A.element) << to_string(description_.A.layout)
<< "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout)
<< "_" << to_string(description_.B.element) << to_string(description_.B.layout)
<< "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout)
<< "_" << to_string(description_.C.element) << to_string(description_.C.layout)
<< "_" << to_string(description_.tile_description.math_instruction.element_accumulator);
name_ = ss.str();
description_.name = name_.c_str();
// Epilogue compute and accumulator type description
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
}
/// Returns the description of the GEMM operation
virtual OperationDescription const & description() const {
return description_;
}
virtual Status can_implement(
void const *configuration,
void const *arguments) const {
return Status::kSuccess;
}
virtual uint64_t get_host_workspace_size(
void const *configuration) const {
return sizeof(GemmUniversalConfiguration);
}
virtual uint64_t get_device_workspace_size(
void const *configuration,
void const *arguments = nullptr) const {
return 0;
}
virtual Status initialize(
void const *configuration,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
return Status::kSuccess;
}
virtual Status run(
void const *arguments,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
using namespace cute;
BlockwiseGemmArguments const &args = *static_cast<BlockwiseGemmArguments const *>(arguments);
// Construct cute::Tensor A/B/C
int M = args.problem_size.m();
int N = args.problem_size.n();
int K = args.problem_size.k();
int L = args.batch_count;
auto problem_shape_MNKL = cute::make_shape(M, N, K, L);
auto alpha = *(static_cast<ElementCompute const*>(args.alpha));
auto beta = *(static_cast<ElementCompute const*>(args.beta));
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutC>;
auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>;
auto A = cute::make_tensor(static_cast<ElementA const*>(args.A),
cute::make_layout(cute::make_shape(M, K, L), stride_a));
auto SfA = make_tensor(static_cast<ElementSFA const*>(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize)));
auto B = cute::make_tensor(static_cast<ElementB const*>(args.B),
cute::make_layout(cute::make_shape(N, K, L), stride_b));
auto SfB = make_tensor(static_cast<ElementSFB const*>(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize)));
auto C = [&]() {
if constexpr (not is_same_v<ElementC, void>) {
return cute::make_tensor(static_cast<ElementC const*>(args.C),
cute::make_layout(cute::make_shape(M, N, L), stride_c));
}
else {
return cute::make_tensor(static_cast<ElementD const*>(nullptr),
cute::make_layout(cute::make_shape(M, N, L), stride_c));
}
}();
auto D = cute::make_tensor(static_cast<ElementD *>(args.D),
cute::make_layout(cute::make_shape(M, N, L), stride_d));
cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
decltype(A), decltype(SfA),
decltype(B), decltype(SfB)>
mainloop_params{A, SfA, B, SfB};
// W/O SF generation
cutlass::reference::host::GettEpilogueParams<
ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute,
decltype(C), decltype(D)>
epilogue_params{alpha, beta, C, D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
return Status::kSuccess;
}
private:
int SFMVecSize;
int SFNVecSize;
int SFKVecSize;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename ElementA_,
typename ElementSFA_,
typename ElementB_,
typename ElementSFB_,
typename ElementC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ElementD_ = ElementC_,
typename ConvertOp_ = NumericConverter<ElementD_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) {
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::RowMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
manifest.append(new BlockwiseGemmReferenceOperation<
Provider::kReferenceHost,
ElementA_,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
ElementSFA_,
ElementB_,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
ElementSFB_,
ElementC_,
cutlass::layout::ColumnMajor,
ElementCompute_,
ElementAccumulator_,
ElementD_,
ConvertOp_,
InnerProductOp_
>(SFMVecSize, SFNVecSize, SFKVecSize));
}
template<class ElementC,
class ElementD>
void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) {
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 1, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 1, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 32, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 32, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 64, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 64, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 256, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 256, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 1, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 1 , 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 128, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 32, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 32, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 64, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 64, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 256, 128);
make_blockwise_gemm<
float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 256, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 1, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 1, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 32, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 32, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 64, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 64, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 256, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 256, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 1 , 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 1, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 1 , 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 64, 128, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 32, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 32, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 64, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 64, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 128, 256, 128);
make_blockwise_gemm<
float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/,
ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/
>(manifest, 1, 256, 128);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////