/*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /* \file \brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library */ #pragma once #include #include #include #include "cutlass/cutlass.h" #include "cutlass/library/library.h" #include "cutlass/library/manifest.h" #include "cutlass/library/util.h" #include "cutlass/util/packed_stride.hpp" #include "library_internal.h" #include "cutlass/util/reference/host/gett.hpp" #include "cutlass/detail/blockwise_scale_layout.hpp" /////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace library { /////////////////////////////////////////////////////////////////////////////////////////////////// template < Provider Provider_, typename ElementA_, typename LayoutA_, typename LayoutSFA_, typename ElementSFA_, typename ElementB_, typename LayoutB_, typename LayoutSFB_, typename ElementSFB_, typename ElementC_, typename LayoutC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, typename ElementD_ = ElementC_, typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > class BlockwiseGemmReferenceOperation : public Operation { public: static Provider const kProvider = Provider_; using ElementA = ElementA_; using LayoutA = LayoutA_; using ElementSFA = ElementSFA_; using ElementB = ElementB_; using LayoutB = LayoutB_; using ElementSFB = ElementSFB_; using ElementC = ElementC_; using LayoutC = LayoutC_; using ElementD = ElementD_; using ElementCompute = ElementCompute_; using ElementAccumulator = ElementAccumulator_; using ConvertOp = ConvertOp_; using InnerProductOp = InnerProductOp_; protected: /// Storage for the name string std::string name_; /// BlockwiseGemmDescription description_; public: /// Constructor BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_) : SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) { // Basic information description_.provider = kProvider; description_.kind = OperationKind::kBlockwiseGemm; description_.gemm_kind = GemmKind::kUniversal; // Tensor description description_.A = make_TensorDescription(); description_.SFA = make_TensorDescription(); description_.B = make_TensorDescription(); description_.SFB = make_TensorDescription(); description_.C = make_TensorDescription(); description_.D = make_TensorDescription(); // Epilogue compute and accumulator type description description_.element_epilogue = NumericTypeMap::kId; description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; // Compute capability for gemm reference description_.tile_description.minimum_compute_capability = (kProvider == Provider::kReferenceDevice ? 50 : 0); description_.tile_description.maximum_compute_capability = 1024; description_.SFMVecSize = SFMVecSize; description_.SFNVecSize = SFNVecSize; description_.SFKVecSize = SFKVecSize; // Procedural name std::stringstream ss; ss << "gemm" << "_reference_" << to_string(description_.provider) << "_" << to_string(description_.A.element) << to_string(description_.A.layout) << "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout) << "_" << to_string(description_.B.element) << to_string(description_.B.layout) << "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout) << "_" << to_string(description_.C.element) << to_string(description_.C.layout) << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); name_ = ss.str(); description_.name = name_.c_str(); // Epilogue compute and accumulator type description description_.element_epilogue = NumericTypeMap::kId; description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; } /// Returns the description of the GEMM operation virtual OperationDescription const & description() const { return description_; } virtual Status can_implement( void const *configuration, void const *arguments) const { return Status::kSuccess; } virtual uint64_t get_host_workspace_size( void const *configuration) const { return sizeof(GemmUniversalConfiguration); } virtual uint64_t get_device_workspace_size( void const *configuration, void const *arguments = nullptr) const { return 0; } virtual Status initialize( void const *configuration, void *host_workspace, void *device_workspace = nullptr, cudaStream_t stream = nullptr) const { return Status::kSuccess; } virtual Status run( void const *arguments, void *host_workspace, void *device_workspace = nullptr, cudaStream_t stream = nullptr) const { using namespace cute; BlockwiseGemmArguments const &args = *static_cast(arguments); // Construct cute::Tensor A/B/C int M = args.problem_size.m(); int N = args.problem_size.n(); int K = args.problem_size.k(); int L = args.batch_count; auto problem_shape_MNKL = cute::make_shape(M, N, K, L); auto alpha = *(static_cast(args.alpha)); auto beta = *(static_cast(args.beta)); using StrideA = cutlass::gemm::TagToStrideA_t; using StrideB = cutlass::gemm::TagToStrideB_t; using StrideC = cutlass::gemm::TagToStrideC_t; using StrideD = cutlass::gemm::TagToStrideC_t; auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>; auto A = cute::make_tensor(static_cast(args.A), cute::make_layout(cute::make_shape(M, K, L), stride_a)); auto SfA = make_tensor(static_cast(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); auto B = cute::make_tensor(static_cast(args.B), cute::make_layout(cute::make_shape(N, K, L), stride_b)); auto SfB = make_tensor(static_cast(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); auto C = [&]() { if constexpr (not is_same_v) { return cute::make_tensor(static_cast(args.C), cute::make_layout(cute::make_shape(M, N, L), stride_c)); } else { return cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, N, L), stride_c)); } }(); auto D = cute::make_tensor(static_cast(args.D), cute::make_layout(cute::make_shape(M, N, L), stride_d)); cutlass::reference::host::GettBlockScalingMainloopParams mainloop_params{A, SfA, B, SfB}; // W/O SF generation cutlass::reference::host::GettEpilogueParams< ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute, decltype(C), decltype(D)> epilogue_params{alpha, beta, C, D}; cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); return Status::kSuccess; } private: int SFMVecSize; int SFNVecSize; int SFKVecSize; }; /////////////////////////////////////////////////////////////////////////////////////////////////// template < typename ElementA_, typename ElementSFA_, typename ElementB_, typename ElementSFB_, typename ElementC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, typename ElementD_ = ElementC_, typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) { manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::RowMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::RowMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); manifest.append(new BlockwiseGemmReferenceOperation< Provider::kReferenceHost, ElementA_, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, ElementSFA_, ElementB_, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, ElementSFB_, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, ElementD_, ConvertOp_, InnerProductOp_ >(SFMVecSize, SFNVecSize, SFKVecSize)); } template void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) { make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 1 , 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 1, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 1 , 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); } /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////////