Hopper Grouped GEMM support for FP8 Accum (#2123)

* Add support for fp8accum, with profiler extension

* Update .gitignore

* contri

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2025-02-20 18:55:26 -08:00
committed by GitHub
parent b84e9802d8
commit 9b3772dfa6
9 changed files with 914 additions and 71 deletions

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
# PyCache files
__pycache__/
cutlass_library.egg-info/
cutlass_library.egg-info/
/build*

View File

@ -1,17 +1,18 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS")
[README](./README.md#documentation) > **Active Developers**
[README](./README.md#documentation) > **Contributors**
# CUTLASS Developers **
Andrew Kerr (CUTLASS founding member)<br />
Andrew Kerr<br />
Paul Springer<br />
Dustyn Blasig<br />
Albert Xu<br />
Junkai Wu<br />
Xiuxia Zhang<br />
Haicheng Wu (CUTLASS founding member)<br />
Haicheng Wu<br />
Jack Yang<br />
Pradeep Ramani (CUTLASS 3.x founding member)<br />
Pradeep Ramani<br />
Aditya Atluri<br />
Han Li<br />
Nick Zhao<br />
@ -20,15 +21,15 @@ Yu-Jung Chen<br />
Markus Hoehnerbach<br />
Honghao Lu<br />
Mihir Awatramani<br />
Hao Sheng<br />
Hao Sheng<br />
Zekun Fan<br />
Aniket Shivam<br />
Aniket Shivam<br />
Siyu Liu<br />
Richard Cai<br />
Vikas Gupta<br />
Ethan Yan<br />
Vijay Thakkar (CUTLASS 3.x and CuTe founding member)<br />
Cris Cecka (CuTe and CUTLASS 3.x founding member)<br />
Vijay Thakkar<br />
Cris Cecka<br />
Lawrence Ryan<br />
Qun Song<br />
Daniel Ricketts<br />
@ -69,5 +70,61 @@ Shreya Gaur<br />
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
# CUTE Developers
Cris Cecka<br />
Vijay Thakkar<br />
# CUTLASS Product Manager
Matthew Nicely<br />
# Former CUTLASS Developers
Manish Gupta<br />
Duane Merrill<br />
Piotr Majcher<br />
Naila Farooqui<br />
Mark Hoemmen<br />
Rawn Henry<br />
Jin Wang<br />
Timmy Liu<br />
Manikandan Ananth<br />
David Tanner<br />
# Acknowledgements
Tri Dao<br />
Jay Shah<br />
Timothy Costa<br />
Julien Demouth<br />
Brian Fahs<br />
Michael Garland<br />
Michael Goldfarb<br />
Mostafa Hagog<br />
Fei Hu<br />
Alan Kaatz<br />
Tina Li<br />
Wei Liu<br />
Tim Martin<br />
Kevin Siu<br />
Markus Tavenrath<br />
John Tran<br />
Vicki Wang<br />
Fung Xie<br />
Yang Xu<br />
Scott Yokim<br />
Girish Bharambe<br />
Luke Durant<br />
Carter Edwards<br />
Olivier Giroux<br />
Stephen Jones<br />
Rishkul Kulkarni<br />
Bryce Lelbach<br />
Joel McCormack<br />
Kyrylo Perelygin<br />
Sean Treichler<br />

View File

@ -234,8 +234,6 @@ struct CollectiveBuilder<
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now.");
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@ -267,12 +265,17 @@ struct CollectiveBuilder<
static constexpr int PipelineStages = detail::compute_stage_count_or_override<Sm90ReducedSmemCapacityBytes,
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
/* For FP8 use a separate mainloop compared to other datatypes */
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<IsFP8Input,
MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
>,
cute::conditional_t<IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
>
>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;

View File

@ -46,8 +46,9 @@
#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
#if !defined(__CUDACC_RTC__)

View File

@ -0,0 +1,768 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
#include "cutlass/numeric_types.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/tensor.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <
int Stages,
class ClusterShape,
class KernelSchedule,
class TileShape_,
class ElementA_,
class StrideA_,
class ElementB_,
class StrideB_,
class TiledMma_,
class GmemTiledCopyA_,
class SmemLayoutAtomA_,
class SmemCopyAtomA_,
class TransformA_,
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>,
TileShape_,
ElementA_,
StrideA_,
ElementB_,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
{
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using InternalStrideA = cute::remove_pointer_t<StrideA>;
using ElementB = ElementB_;
using StrideB = StrideB_;
using InternalStrideB = cute::remove_pointer_t<StrideB>;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
// One threads per CTA are producers (1 for operand tile)
static constexpr int NumProducerThreadEvents = 1;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128, _0> {
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
} tensors;
struct TensorMapStorage : cute::aligned_struct<128, _0> {
cute::TmaDescriptor smem_tensormap_A;
cute::TmaDescriptor smem_tensormap_B;
} tensormaps;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;
// Host side kernel arguments
struct Arguments {
ElementA const** ptr_A;
StrideA dA;
ElementB const** ptr_B;
StrideB dB;
uint32_t mma_promotion_interval = 4;
};
// Device side kernel params
struct Params {
TMA_A tma_load_a;
TMA_B tma_load_b;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
uint32_t mma_promotion_interval = 4;
void* tensormaps;
ElementA const** ptr_A;
StrideA dA;
ElementB const** ptr_B;
StrideB dB;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(
ProblemShape problem_shapes,
Arguments const& args,
void* workspace) {
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
// These will be replaced with correct values before the initial tma load.
auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
auto init_M = get<0>(init_shape);
auto init_N = get<1>(init_shape);
auto init_K = get<2>(init_shape);
auto init_L = get<3>(init_shape);
ElementA const* ptr_A_first_batch = reinterpret_cast<ElementA const*>(args.ptr_A);
ElementB const* ptr_B_first_batch = reinterpret_cast<ElementB const*>(args.ptr_B);
InternalStrideA stride_a;
InternalStrideB stride_b;
if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
stride_a = InternalStrideA{};
stride_b = InternalStrideB{};
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
init_M = get<0>(problem_shape_MNK);
init_N = get<1>(problem_shape_MNK);
init_K = get<2>(problem_shape_MNK);
stride_a = args.dA;
stride_b = args.dB;
}
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a));
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b));
TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{},
tensor_a,
SmemLayoutA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
TMA_B tma_load_b = make_tma_copy(
GmemTiledCopyB{},
tensor_b,
SmemLayoutB{}(_,_,cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
void* tensormaps = workspace;
return {
tma_load_a,
tma_load_b,
TmaTransactionBytes,
args.mma_promotion_interval,
tensormaps,
reinterpret_cast<ElementA const**>(args.ptr_A),
args.dA,
reinterpret_cast<ElementB const**>(args.ptr_B),
args.dB
};
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
constexpr uint32_t NumInputTensors = 2;
constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
// Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies
return (NumInputTensors * SizeOfCuTensorMap * sm_count);
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
template<class ProblemShape>
static bool
can_implement(
ProblemShape problem_shapes,
Arguments const& args) {
constexpr int tma_alignment_bits = 128;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
bool implementable = true;
if (problem_shapes.is_host_problem_shape_available()) {
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), InternalStrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), InternalStrideB{});
}
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes =
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value))+
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
// Set up the data needed by this collective for load and mma.
// Returns a tuple of tensors. The collective and the kernel layer have the contract that the
// returned tuple must contain at least two elements, with the first two elements being:
// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
const int32_t mock_L = 1;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl);
}
// Perform a collective-scoped matrix multiply-accumulate
// Producer Perspective
template <
class TensorA, class TensorB,
class TensorMapA, class TensorMapB,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load(
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB> const& load_inputs,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(
MainloopPipeline pipeline,
PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <
class FrgTensorC
>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline,
PipelineState smem_pipe_read,
FrgTensorC& accum,
int k_tile_count,
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Define C accumulators and A/B partitioning
//
// Layout of warp group to thread mapping
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
stride<0>(typename TiledMma::BLayout{}) == 0 and
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
"ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps(dma load) with some mmas in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA));
warpgroup_fence_operand(accumulation());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
accumulation.promote_if_needed();
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation());
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
warpgroup_fence_operand(accumulation());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation());
accumulation.promote_if_needed();
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
accumulation.promote_residue_if_needed();
warpgroup_fence_operand(accumulation());
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
//
// Methods to perform different parts of TMA/Tensormap modifications
//
CUTLASS_DEVICE auto
tensormaps_init(
Params const& mainloop_params,
TensorMapStorage& shared_tensormaps,
int32_t sm_count,
int32_t sm_idx) {
cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);
cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];
if (cute::elect_one_sync()) {
// Bringing tensormaps from params to smem for modification later
Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});
copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));
}
__syncwarp();
return cute::make_tuple(tma_desc_a, tma_desc_b);
}
// Replace address for the global tensor (to be done by single thread)
CUTLASS_DEVICE
void
tensormaps_replace_global_address(
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
int32_t next_batch) {
// Replacing global_address for the next batch
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A,
mainloop_params.ptr_A[next_batch]);
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B,
mainloop_params.ptr_B[next_batch]);
}
// Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_replace_global_tensor_properties(
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
int32_t next_group,
ProblemShape_MNKL problem_shape_mnkl) {
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
const uint32_t K = get<2>(problem_shape_mnkl);
// Replace all dims for consistency
constexpr int MaxTensorRank = 5;
cute::array<uint32_t, MaxTensorRank> prob_shape_A = {1,1,1,1,1};
cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0,0,0,0,0};
cute::array<uint32_t, MaxTensorRank> prob_shape_B = {1,1,1,1,1};
cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0,0,0,0,0};
ElementA const* ptr_A = nullptr;
Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);
ElementB const* ptr_B = nullptr;
Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a,
prob_shape_A, prob_stride_A);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
prob_shape_B, prob_stride_B);
// Convert strides to byte strides
for (uint64_t& stride : prob_stride_A) {
stride = (stride * sizeof_bits_v<ElementA>) / 8;
}
for (uint64_t& stride : prob_stride_B) {
stride = (stride * sizeof_bits_v<ElementB>) / 8;
}
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A,
prob_shape_A,
prob_stride_A);
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B,
prob_shape_B,
prob_stride_B);
}
template <class TensorMapA, class TensorMapB, class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_perform_update(
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
ProblemShape_MNKL problem_shape_mnkl,
int32_t next_batch) {
if (cute::elect_one_sync()) {
// Replacing global_address for the next batch
tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);
if constexpr (IsGroupedGemmKernel) {
// Replacing global dims and strides for the next batch
tensormaps_replace_global_tensor_properties(shared_tensormaps,
mainloop_params, next_batch, problem_shape_mnkl);
}
}
}
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_cp_fence_release (
TensorMapStorage& shared_tensormaps,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
// Entire warp must do this (i.e. it's aligned)
tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A);
tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B);
}
// The entire warp must call this function collectively (that is, the instructions are aligned)
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_fence_acquire(cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -336,6 +336,21 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
// For FP8 kernels
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
>
struct MainloopSm90ArrayTmaGmmaWarpSpecializedFP8
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>,
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper sparse GMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,

View File

@ -488,6 +488,10 @@ class KernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
ImplicitTmaWarpSpecializedSm90 = enum_auto()
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
TmaWarpSpecialized1SmSm100 = enum_auto()
TmaWarpSpecialized2SmSm100 = enum_auto()
@ -514,11 +518,6 @@ class KernelScheduleType(enum.Enum):
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
#
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
@ -551,10 +550,10 @@ KernelScheduleTag = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
@ -598,10 +597,10 @@ KernelScheduleSuffixes = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
@ -667,8 +666,8 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm',
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
}
class EpilogueFunctor3x(enum.Enum):
@ -686,6 +685,15 @@ def to_grouped_schedule(schedule, grouped):
return schedule
group_schedule_map = {
# SM90
KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
# SM100
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,

View File

@ -494,8 +494,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
# the following cases are unsupported by grouped GEMM
if not is_aligned:
return [], []
if not can_do_tma_epilogue:
return [], []
if requires_transposed_epilogue:
return [], []
@ -513,16 +511,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
return [], []
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
schedules = []
if not grouped:
schedules.append(
[
KernelScheduleType.TmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecializedCooperative
])
schedules.append(
[
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
schedules.append(
[
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
return schedules, []
return [], []
@ -586,18 +583,9 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
return schedules, stream_k_schedules
if grouped:
pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum
if can_do_tma_epilogue:
schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong])
if can_do_cooperative:
schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative])
return schedules, []
schedules = []
# Pruning: emit Void-C kernels with persistent kernels only
if level >= 1 or not is_void_c:
# Pruning: emit Void-C and Grouped kernels with persistent kernels only
if (level >= 1 or not is_void_c) and not grouped:
# Pruning: don't stamp out fp8 kernels with auto schedule
if not is_fp8:
schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
@ -610,28 +598,29 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
if not is_fp8 or level >= 1:
schedules.append([
KernelScheduleType.TmaWarpSpecializedPingpong,
EpilogueScheduleType.TmaWarpSpecialized
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
])
if can_do_fp8_fast_accum:
schedules.append([
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecialized
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
])
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
# Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
# Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue
if not is_fp8 or level >= 1:
schedules.append([KernelScheduleType.TmaWarpSpecializedPingpong, default_epilogue])
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
if can_do_fp8_fast_accum:
schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
schedules.append([KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, default_epilogue])
if not grouped:
schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
if can_do_cooperative:
schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
default_epilogue
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(default_epilogue, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
@ -639,8 +628,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
])
if can_do_fp8_fast_accum:
schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
default_epilogue
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
to_grouped_schedule(default_epilogue, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
@ -652,8 +641,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
assert not requires_transposed_epilogue
if can_do_cooperative:
schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
EpilogueScheduleType.TmaWarpSpecializedCooperative
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
@ -661,14 +650,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
])
if can_do_fp8_fast_accum:
schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecializedCooperative
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecializedCooperative
])
# Grouped GEMM do not support Stream-K scheduler
if grouped:
return schedules, []
return schedules, stream_k_schedules

View File

@ -204,9 +204,6 @@ protected:
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
return Status::kSuccess;
}
@ -215,6 +212,8 @@ protected:
fusion_args.beta = 0;
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(arguments.alpha);
fusion_args.beta_ptr = static_cast<ElementCompute const*>(arguments.beta);
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
return Status::kSuccess;
}
else {