diff --git a/.gitignore b/.gitignore
index acddb1f9..e7a02687 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
# PyCache files
__pycache__/
-cutlass_library.egg-info/
\ No newline at end of file
+cutlass_library.egg-info/
+/build*
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 1ef06a36..843ed365 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -1,17 +1,18 @@

-[README](./README.md#documentation) > **Active Developers**
+[README](./README.md#documentation) > **Contributors**
# CUTLASS Developers **
-Andrew Kerr (CUTLASS founding member)
+Andrew Kerr
+Paul Springer
Dustyn Blasig
Albert Xu
Junkai Wu
Xiuxia Zhang
-Haicheng Wu (CUTLASS founding member)
+Haicheng Wu
Jack Yang
-Pradeep Ramani (CUTLASS 3.x founding member)
+Pradeep Ramani
Aditya Atluri
Han Li
Nick Zhao
@@ -20,15 +21,15 @@ Yu-Jung Chen
Markus Hoehnerbach
Honghao Lu
Mihir Awatramani
-Hao Sheng
+Hao Sheng
Zekun Fan
-Aniket Shivam
+Aniket Shivam
Siyu Liu
Richard Cai
Vikas Gupta
Ethan Yan
-Vijay Thakkar (CUTLASS 3.x and CuTe founding member)
-Cris Cecka (CuTe and CUTLASS 3.x founding member)
+Vijay Thakkar
+Cris Cecka
Lawrence Ryan
Qun Song
Daniel Ricketts
@@ -69,5 +70,61 @@ Shreya Gaur
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
+
+# CuTe Developers
+
+Cris Cecka
+Vijay Thakkar
+
+
# CUTLASS Product Manager
+
Matthew Nicely
+
+
+# Former CUTLASS Developers
+
+Manish Gupta
+Duane Merrill
+Piotr Majcher
+Naila Farooqui
+Mark Hoemmen
+Rawn Henry
+Jin Wang
+Timmy Liu
+Manikandan Ananth
+David Tanner
+
+
+# Acknowledgements
+
+Tri Dao
+Jay Shah
+Timothy Costa
+Julien Demouth
+Brian Fahs
+Michael Garland
+Michael Goldfarb
+Mostafa Hagog
+Fei Hu
+Alan Kaatz
+Tina Li
+Wei Liu
+Tim Martin
+Kevin Siu
+Markus Tavenrath
+John Tran
+Vicki Wang
+Fung Xie
+Yang Xu
+Scott Yokim
+Girish Bharambe
+Luke Durant
+Carter Edwards
+Olivier Giroux
+Stephen Jones
+Rishkul Kulkarni
+Bryce Lelbach
+Joel McCormack
+Kyrylo Perelygin
+Sean Treichler
diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
index 7736dbee..504abcc6 100644
--- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
+++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
@@ -234,8 +234,6 @@ struct CollectiveBuilder<
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
- static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
- "KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now.");
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@@ -267,12 +265,17 @@ struct CollectiveBuilder<
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes, ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
+  /* For FP8 use a separate mainloop compared to other datatypes */
   using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
-      MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
-      /* For FP8 use a separate mainloop compared to other datatypes */
+      cute::conditional_t<IsFP8Input,
+        MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
+        MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
+      >,
       cute::conditional_t<IsFP8Input,
         MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
-        MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
+        MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
+      >
+    >;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp
index f6a0dcb3..e33c06a7 100644
--- a/include/cutlass/gemm/collective/collective_mma.hpp
+++ b/include/cutlass/gemm/collective/collective_mma.hpp
@@ -46,8 +46,9 @@
#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
-#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp"
+#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
+#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
#if !defined(__CUDACC_RTC__)
diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp
new file mode 100644
index 00000000..b3d857ff
--- /dev/null
+++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp
@@ -0,0 +1,768 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/fp8_accumulation.hpp"
+#include "cutlass/trace.h"
+#include "cutlass/numeric_types.h"
+
+#include "cute/arch/cluster_sm90.hpp"
+#include "cute/arch/copy_sm90.hpp"
+#include "cute/algorithm/functional.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cute/algorithm/gemm.hpp"
+#include "cute/tensor_predicate.hpp"
+#include "cute/tensor.hpp"
+#include "cute/numeric/arithmetic_tuple.hpp"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::gemm::collective {
+using namespace cute;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// WarpSpecialized Mainloop
+template <
+ int Stages,
+ class ClusterShape,
+ class KernelSchedule,
+ class TileShape_,
+ class ElementA_,
+ class StrideA_,
+ class ElementB_,
+ class StrideB_,
+ class TiledMma_,
+ class GmemTiledCopyA_,
+ class SmemLayoutAtomA_,
+ class SmemCopyAtomA_,
+ class TransformA_,
+ class GmemTiledCopyB_,
+ class SmemLayoutAtomB_,
+ class SmemCopyAtomB_,
+ class TransformB_>
+struct CollectiveMma<
+ MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>,
+ TileShape_,
+ ElementA_,
+ StrideA_,
+ ElementB_,
+ StrideB_,
+ TiledMma_,
+ GmemTiledCopyA_,
+ SmemLayoutAtomA_,
+ SmemCopyAtomA_,
+ TransformA_,
+ GmemTiledCopyB_,
+ SmemLayoutAtomB_,
+ SmemCopyAtomB_,
+ TransformB_>
+{
+ //
+ // Type Aliases
+ //
+ using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
+ using TileShape = TileShape_;
+ using ElementA = ElementA_;
+ using StrideA = StrideA_;
+ using InternalStrideA = cute::remove_pointer_t<StrideA>;
+ using ElementB = ElementB_;
+ using StrideB = StrideB_;
+ using InternalStrideB = cute::remove_pointer_t<StrideB>;
+ using TiledMma = TiledMma_;
+ using ElementAccumulator = typename TiledMma::ValTypeC;
+ using GmemTiledCopyA = GmemTiledCopyA_;
+ using GmemTiledCopyB = GmemTiledCopyB_;
+ using SmemLayoutAtomA = SmemLayoutAtomA_;
+ using SmemLayoutAtomB = SmemLayoutAtomB_;
+ using SmemCopyAtomA = SmemCopyAtomA_;
+ using SmemCopyAtomB = SmemCopyAtomB_;
+ using TransformA = TransformA_;
+ using TransformB = TransformB_;
+ using ArchTag = typename DispatchPolicy::ArchTag;
+
+ using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
+ using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
+ using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
+
+ using PipelineParams = typename MainloopPipeline::Params;
+
+ // One thread per CTA is a producer (1 for operand tile)
+ static constexpr int NumProducerThreadEvents = 1;
+
+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
+
+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
+
+ // Tile along modes in a way that maximizes the TMA box size.
+ using SmemLayoutA = decltype(tile_to_shape(
+ SmemLayoutAtomA{},
+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
+ cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
+ using SmemLayoutB = decltype(tile_to_shape(
+ SmemLayoutAtomB{},
+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
+ cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
+
+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
+ static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
+               cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
+ "MMA atom must source both A and B operand from smem_desc for this mainloop.");
+ static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
+ "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
+ static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
+ "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
+
+ // Assumption: StrideA is congruent with Problem_MK
+ using TMA_A = decltype(make_tma_copy(
+ GmemTiledCopyA{},
+ make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
+ SmemLayoutA{}(_,_,cute::Int<0>{}),
+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
+ size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
+ // Assumption: StrideB is congruent with Problem_NK
+ using TMA_B = decltype(make_tma_copy(
+ GmemTiledCopyB{},
+ make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
+ SmemLayoutB{}(_,_,cute::Int<0>{}),
+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
+ size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
+
+ struct SharedStorage {
+ struct TensorStorage : cute::aligned_struct<128, _0> {
+ cute::array_aligned<ElementA, cute::cosize_v<SmemLayoutA>> smem_A;
+ cute::array_aligned<ElementB, cute::cosize_v<SmemLayoutB>> smem_B;
+ } tensors;
+
+ struct TensorMapStorage : cute::aligned_struct<128, _0> {
+ cute::TmaDescriptor smem_tensormap_A;
+ cute::TmaDescriptor smem_tensormap_B;
+ } tensormaps;
+
+ using PipelineStorage = typename MainloopPipeline::SharedStorage;
+ PipelineStorage pipeline;
+ };
+ using TensorStorage = typename SharedStorage::TensorStorage;
+ using TensorMapStorage = typename SharedStorage::TensorMapStorage;
+ using PipelineStorage = typename SharedStorage::PipelineStorage;
+
+ static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;
+
+ // Host side kernel arguments
+ struct Arguments {
+ ElementA const** ptr_A;
+ StrideA dA;
+ ElementB const** ptr_B;
+ StrideB dB;
+ uint32_t mma_promotion_interval = 4;
+ };
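+ // Illustrative note (not part of the original source): ptr_A/ptr_B are device-visible
+ // arrays of per-group operand pointers, while dA/dB hold either a single stride
+ // (ptr-array GEMM) or a pointer to per-group strides (grouped GEMM, where StrideA/StrideB
+ // are themselves pointer types). mma_promotion_interval sets how many K-blocks of FP8
+ // GMMAs accumulate before the partial result is promoted into the main accumulator
+ // (see GmmaFP8Accumulation in fp8_accumulation.hpp). A minimal host-side sketch, assuming
+ // hypothetical device buffers d_ptr_A/d_ptr_B and strides stride_A/stride_B:
+ //
+ //   Arguments args{d_ptr_A, stride_A, d_ptr_B, stride_B, /*mma_promotion_interval=*/4};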
+
+ // Device side kernel params
+ struct Params {
+ TMA_A tma_load_a;
+ TMA_B tma_load_b;
+ uint32_t tma_transaction_bytes = TmaTransactionBytes;
+ uint32_t mma_promotion_interval = 4;
+ void* tensormaps;
+ ElementA const** ptr_A;
+ StrideA dA;
+ ElementB const** ptr_B;
+ StrideB dB;
+ };
+
+ //
+ // Methods
+ //
+
+ template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(
+ ProblemShape problem_shapes,
+ Arguments const& args,
+ void* workspace) {
+ // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
+ // These will be replaced with correct values before the initial tma load.
+ auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
+ auto init_M = get<0>(init_shape);
+ auto init_N = get<1>(init_shape);
+ auto init_K = get<2>(init_shape);
+ auto init_L = get<3>(init_shape);
+
+ ElementA const* ptr_A_first_batch = reinterpret_cast<ElementA const*>(args.ptr_A);
+ ElementB const* ptr_B_first_batch = reinterpret_cast<ElementB const*>(args.ptr_B);
+
+ InternalStrideA stride_a;
+ InternalStrideB stride_b;
+ if constexpr (IsGroupedGemmKernel) {
+ // Strides for Grouped Gemm will be replaced prior to the first access regardless.
+ stride_a = InternalStrideA{};
+ stride_b = InternalStrideB{};
+ }
+ else {
+ // Tensor shapes for Ptr-Array are initialized correctly only here.
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
+ init_M = get<0>(problem_shape_MNK);
+ init_N = get<1>(problem_shape_MNK);
+ init_K = get<2>(problem_shape_MNK);
+
+ stride_a = args.dA;
+ stride_b = args.dB;
+ }
+ Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a));
+ Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b));
+ TMA_A tma_load_a = make_tma_copy(
+ GmemTiledCopyA{},
+ tensor_a,
+ SmemLayoutA{}(_,_,cute::Int<0>{}),
+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
+ size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
+ TMA_B tma_load_b = make_tma_copy(
+ GmemTiledCopyB{},
+ tensor_b,
+ SmemLayoutB{}(_,_,cute::Int<0>{}),
+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
+ size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
+
+ void* tensormaps = workspace;
+
+ return {
+ tma_load_a,
+ tma_load_b,
+ TmaTransactionBytes,
+ args.mma_promotion_interval,
+ tensormaps,
+ reinterpret_cast<ElementA const**>(args.ptr_A),
+ args.dA,
+ reinterpret_cast<ElementB const**>(args.ptr_B),
+ args.dB
+ };
+ }
+
+ template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
+ constexpr uint32_t NumInputTensors = 2;
+ constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
+ // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies
+ return (NumInputTensors * SizeOfCuTensorMap * sm_count);
+ }
+
+ template <class ProblemShape>
+ static cutlass::Status
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) {
+ return cutlass::Status::kSuccess;
+ }
+
+ template <class ProblemShape>
+ static bool
+ can_implement(
+ ProblemShape problem_shapes,
+ Arguments const& args) {
+ constexpr int tma_alignment_bits = 128;
+ constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
+ constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
+
+ bool implementable = true;
+ if (problem_shapes.is_host_problem_shape_available()) {
+ // Check alignment for all problem sizes
+ for (int i = 0; i < problem_shapes.groups(); i++) {
+ auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
+ auto [M,N,K,L] = problem_shape_MNKL;
+ implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), InternalStrideA{});
+ implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), InternalStrideB{});
+ }
+ }
+
+ if (!implementable) {
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
+ }
+ return implementable;
+ }
+
+ static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
+ static constexpr int K_PIPE_MMAS = 1;
+ static constexpr uint32_t TmaTransactionBytes =
+ cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) +
+ cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
+
+ // Set up the data needed by this collective for load and mma.
+ // Returns a tuple of tensors. The collective and the kernel layer have the contract that the
+ // returned tuple must contain at least two elements, with the first two elements being:
+ // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
+ // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
+ // The rest of the tensors can be specified as needed by this collective.
+ template <class ProblemShape_MNKL>
+ CUTLASS_DEVICE auto
+ load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
+ using X = Underscore;
+ // Separate out problem shape for convenience
+ auto [M,N,K,L] = problem_shape_MNKL;
+ const int32_t mock_L = 1;
+
+ // TMA requires special handling of strides to deal with coord codomain mapping
+ // Represent the full tensors -- get these from TMA
+ Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l)
+ Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l)
+
+ // Make tiled views, defer the slice
+ Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
+ Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
+
+ return cute::make_tuple(gA_mkl, gB_nkl);
+ }
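+
+ // Worked example (illustrative, not from the original source): with TileShape
+ // (BLK_M,BLK_N,BLK_K) = (128,128,64) and a problem of size (M,N,K) = (512,256,1024),
+ // gA_mkl has shape (BLK_M,BLK_K,m,k,l) = (128,64,4,16,1) and gB_nkl has shape
+ // (BLK_N,BLK_K,n,k,l) = (128,64,2,16,1). mock_L is 1 because the batch/group dimension
+ // is handled through per-batch tensormap updates rather than through the l mode of
+ // these tensors.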
+
+ // Perform a collective-scoped matrix multiply-accumulate
+ // Producer Perspective
+ template <
+ class TensorA, class TensorB,
+ class TensorMapA, class TensorMapB,
+ class KTileIterator, class BlockCoord
+ >
+ CUTLASS_DEVICE void
+ load(
+ Params const& mainloop_params,
+ MainloopPipeline pipeline,
+ PipelineState smem_pipe_write,
+ cute::tuple<TensorA, TensorB> const& load_inputs,
+ cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
+ BlockCoord const& blk_coord,
+ KTileIterator k_tile_iter, int k_tile_count,
+ int thread_idx,
+ uint32_t block_rank_in_cluster,
+ TensorStorage& shared_tensors) {
+ int lane_predicate = cute::elect_one_sync();
+
+ if (lane_predicate) {
+ Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
+ Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
+
+ //
+ // Prepare the TMA loads for A and B
+ //
+
+ constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
+ uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
+
+ Tensor gA_mkl = get<0>(load_inputs);
+ Tensor gB_nkl = get<1>(load_inputs);
+
+ auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
+ auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
+
+ // Partition the inputs based on the current block coordinates.
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
+
+ // Applies the mapping from block_tma_a
+ Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
+ Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
+
+ Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
+ Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
+
+ uint16_t mcast_mask_a = 0;
+ uint16_t mcast_mask_b = 0;
+
+ // Issue TmaLoads
+ // Maps the tile -> block, value
+ if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
+ auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
+ for (int n = 0; n < size<1>(block_layout); ++n) {
+ mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
+ }
+ }
+
+ if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
+ auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
+ for (int m = 0; m < size<0>(block_layout); ++m) {
+ mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
+ }
+ }
+
+ // Mainloop
+ CUTLASS_PRAGMA_NO_UNROLL
+ for ( ; k_tile_count > 0; --k_tile_count) {
+ // LOCK smem_pipe_write for _writing_
+ pipeline.producer_acquire(smem_pipe_write);
+
+ //
+ // Copy gmem to smem for *k_tile_iter
+ //
+
+ using BarrierType = typename MainloopPipeline::ProducerBarrierType;
+ BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
+
+ int write_stage = smem_pipe_write.index();
+ copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
+ copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
+ ++k_tile_iter;
+
+ // Advance smem_pipe_write
+ ++smem_pipe_write;
+ }
+ }
+ }
+
+ /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
+ CUTLASS_DEVICE void
+ load_tail(
+ MainloopPipeline pipeline,
+ PipelineState smem_pipe_write) {
+ int lane_predicate = cute::elect_one_sync();
+
+ // Issue the epilogue waits
+ if (lane_predicate) {
+ /* This helps avoid early exit of blocks in Cluster
+ * Waits for all stages to either be released (all
+ * Consumer UNLOCKs), or if the stage was never used
+ * then would just be acquired since the phase was
+ * still inverted from make_producer_start_state
+ */
+ pipeline.producer_tail(smem_pipe_write);
+ }
+ }
+
+ /// Perform a collective-scoped matrix multiply-accumulate
+ /// Consumer Perspective
+ template <
+ class FrgTensorC
+ >
+ CUTLASS_DEVICE void
+ mma(MainloopPipeline pipeline,
+ PipelineState smem_pipe_read,
+ FrgTensorC& accum,
+ int k_tile_count,
+ int thread_idx,
+ TensorStorage& shared_tensors,
+ Params const& mainloop_params) {
+
+ static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+ static_assert(cute::is_void_v<SmemCopyAtomA>,
+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
+ static_assert(cute::is_void_v<SmemCopyAtomB>,
+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
+
+ Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
+ Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
+
+ //
+ // Define C accumulators and A/B partitioning
+ //
+
+ // Layout of warp group to thread mapping
+
+ static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
+ stride<0>(typename TiledMma::BLayout{}) == 0 and
+ size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
+ size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
+ "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
+
+ constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
+ Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
+ Int<NumThreadsPerWarpGroup>{});
+
+ int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+
+ TiledMma tiled_mma;
+ auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
+
+ Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
+ Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
+
+ // Allocate "fragments/descriptors"
+ Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
+ Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
+
+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
+ CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
+ CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
+
+ //
+ // PIPELINED MAIN LOOP
+ //
+ static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
+ "ERROR : Incorrect number of MMAs in flight");
+
+ // We release buffers to producer warps(dma load) with some mmas in flight
+ PipelineState smem_pipe_release = smem_pipe_read;
+
+ // Prologue GMMAs
+ int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
+
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+
+ GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA));
+ warpgroup_fence_operand(accumulation());
+ CUTLASS_PRAGMA_UNROLL
+ for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
+ {
+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
+ auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
+ pipeline.consumer_wait(smem_pipe_read, barrier_token);
+
+ if (accumulation.prepare_if_needed()) {
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+ }
+
+ int read_stage = smem_pipe_read.index();
+ warpgroup_arrive();
+ // Unroll the K mode manually to set scale D to 1
+ CUTLASS_PRAGMA_UNROLL
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+ // (V,M,K) x (V,N,K) => (V,M,N)
+ cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ }
+ warpgroup_commit_batch();
+
+ accumulation.promote_if_needed();
+
+ ++smem_pipe_read;
+ }
+
+ warpgroup_fence_operand(accumulation());
+ // Mainloop GMMAs
+ k_tile_count -= prologue_mma_count;
+
+ CUTLASS_PRAGMA_NO_UNROLL
+ for ( ; k_tile_count > 0; --k_tile_count)
+ {
+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
+ auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
+ pipeline.consumer_wait(smem_pipe_read, barrier_token);
+
+ //
+ // Compute on k_tile
+ //
+
+ int read_stage = smem_pipe_read.index();
+
+ if (accumulation.prepare_if_needed()) {
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+ }
+
+ warpgroup_fence_operand(accumulation());
+ warpgroup_arrive();
+ // Unroll the K mode manually to set scale D to 1
+ CUTLASS_PRAGMA_UNROLL
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+ // (V,M,K) x (V,N,K) => (V,M,N)
+ cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ }
+ warpgroup_commit_batch();
+
+ /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
+ warpgroup_wait<K_PIPE_MMAS>();
+ warpgroup_fence_operand(accumulation());
+
+ accumulation.promote_if_needed();
+
+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
+
+ // Advance smem_pipe_read and smem_pipe_release
+ ++smem_pipe_read;
+ ++smem_pipe_release;
+ }
+
+ accumulation.promote_residue_if_needed();
+
+ warpgroup_fence_operand(accumulation());
+ }
+
+ /// Perform a Consumer Epilogue to release all buffers
+ CUTLASS_DEVICE void
+ mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
+ // Prologue GMMAs
+ int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
+ k_tile_count -= prologue_mma_count;
+
+ smem_pipe_release.advance(k_tile_count);
+
+ // Wait on all GMMAs to complete
+ warpgroup_wait<0>();
+
+ for (int count = 0; count < prologue_mma_count; ++count) {
+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
+ ++smem_pipe_release;
+ }
+ }
+
+ //
+ // Methods to perform different parts of TMA/Tensormap modifications
+ //
+
+ CUTLASS_DEVICE auto
+ tensormaps_init(
+ Params const& mainloop_params,
+ TensorMapStorage& shared_tensormaps,
+ int32_t sm_count,
+ int32_t sm_idx) {
+ cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);
+
+ cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
+ cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];
+
+ if (cute::elect_one_sync()) {
+ // Bringing tensormaps from params to smem for modification later
+ Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
+ Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
+ Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
+ Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});
+
+ copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
+ copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));
+ }
+ __syncwarp();
+
+ return cute::make_tuple(tma_desc_a, tma_desc_b);
+ }
+
+ // Replace address for the global tensor (to be done by single thread)
+ CUTLASS_DEVICE
+ void
+ tensormaps_replace_global_address(
+ TensorMapStorage& shared_tensormaps,
+ Params const& mainloop_params,
+ int32_t next_batch) {
+ // Replacing global_address for the next batch
+ cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A,
+ mainloop_params.ptr_A[next_batch]);
+ cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B,
+ mainloop_params.ptr_B[next_batch]);
+ }
+
+ // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
+ template <class ProblemShape_MNKL>
+ CUTLASS_DEVICE
+ void
+ tensormaps_replace_global_tensor_properties(
+ TensorMapStorage& shared_tensormaps,
+ Params const& mainloop_params,
+ int32_t next_group,
+ ProblemShape_MNKL problem_shape_mnkl) {
+ const uint32_t M = get<0>(problem_shape_mnkl);
+ const uint32_t N = get<1>(problem_shape_mnkl);
+ const uint32_t K = get<2>(problem_shape_mnkl);
+ // Replace all dims for consistency
+ constexpr int MaxTensorRank = 5;
+ cute::array<uint32_t, MaxTensorRank> prob_shape_A  = {1,1,1,1,1};
+ cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0,0,0,0,0};
+ cute::array<uint32_t, MaxTensorRank> prob_shape_B  = {1,1,1,1,1};
+ cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0,0,0,0,0};
+
+ ElementA const* ptr_A = nullptr;
+ Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);
+
+ ElementB const* ptr_B = nullptr;
+ Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]);
+
+ cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a,
+ prob_shape_A, prob_stride_A);
+ cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
+ prob_shape_B, prob_stride_B);
+
+ // Convert strides to byte strides
+ for (uint64_t& stride : prob_stride_A) {
+ stride = (stride * sizeof_bits_v<ElementA>) / 8;
+ }
+ for (uint64_t& stride : prob_stride_B) {
+ stride = (stride * sizeof_bits_v<ElementB>) / 8;
+ }
+
+ cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A,
+ prob_shape_A,
+ prob_stride_A);
+ cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B,
+ prob_shape_B,
+ prob_stride_B);
+ }
+
+ template <class TensorMapA, class TensorMapB, class ProblemShape_MNKL>
+ CUTLASS_DEVICE
+ void
+ tensormaps_perform_update(
+ TensorMapStorage& shared_tensormaps,
+ Params const& mainloop_params,
+ cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
+ ProblemShape_MNKL problem_shape_mnkl,
+ int32_t next_batch) {
+ if (cute::elect_one_sync()) {
+ // Replacing global_address for the next batch
+ tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);
+
+ if constexpr (IsGroupedGemmKernel) {
+ // Replacing global dims and strides for the next batch
+ tensormaps_replace_global_tensor_properties(shared_tensormaps,
+ mainloop_params, next_batch, problem_shape_mnkl);
+ }
+ }
+ }
+
+ template <class TensorMapA, class TensorMapB>
+ CUTLASS_DEVICE
+ void
+ tensormaps_cp_fence_release (
+ TensorMapStorage& shared_tensormaps,
+ cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
+ // Entire warp must do this (i.e. it's aligned)
+ tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A);
+ tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B);
+ }
+
+ // The entire warp must call this function collectively (that is, the instructions are aligned)
+ template <class TensorMapA, class TensorMapB>
+ CUTLASS_DEVICE
+ void
+ tensormaps_fence_acquire(cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
+ cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
+ cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
+ }
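+
+ // Usage sketch (illustrative, not from the original source): the ptr-array/grouped kernel
+ // layer is expected to drive these helpers roughly as follows from the producer warp:
+ //
+ //   auto input_tensormaps = collective.tensormaps_init(params, shared.tensormaps, sm_count, sm_idx);
+ //   // for every new batch/group picked up by this SM:
+ //   collective.tensormaps_perform_update(shared.tensormaps, params, input_tensormaps, problem_shape, next_batch);
+ //   collective.tensormaps_cp_fence_release(shared.tensormaps, input_tensormaps); // whole warp
+ //   collective.tensormaps_fence_acquire(input_tensormaps);                       // before TMA loads use them
+ //
+ // Names such as `collective`, `shared`, and `next_batch` are placeholders for whatever the
+ // kernel layer actually uses.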
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::gemm::collective
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp
index 8747f48b..0936eb25 100644
--- a/include/cutlass/gemm/dispatch_policy.hpp
+++ b/include/cutlass/gemm/dispatch_policy.hpp
@@ -336,6 +336,21 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
};
+// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
+// For FP8 kernels
+template<
+ int Stages_,
+ class ClusterShape_ = Shape<_1,_1,_1>,
+ class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
+>
+struct MainloopSm90ArrayTmaGmmaWarpSpecializedFP8
+ : MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
+ static_assert(
+ cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
+ cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>,
+ "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
+};
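+
+// Usage sketch (illustrative only): this policy is normally not instantiated directly but is
+// selected by the SM90 collective builder when FP8 operands are combined with a Ptr-Array /
+// Grouped GEMM warp-specialized schedule; tile/cluster shapes below are arbitrary examples:
+//
+//   using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+//       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+//       cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,
+//       cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,
+//       float, cute::Shape<cute::_128,cute::_128,cute::_128>, cute::Shape<cute::_1,cute::_1,cute::_1>,
+//       cutlass::gemm::collective::StageCountAuto,
+//       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::CollectiveOp;
+//
+// which resolves the mainloop DispatchPolicy to MainloopSm90ArrayTmaGmmaWarpSpecializedFP8.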
+
// n-buffer in smem (Hopper TMA), pipelined with Hopper sparse GMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,
diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py
index bc2cc7b1..d5f02db0 100644
--- a/python/cutlass_library/library.py
+++ b/python/cutlass_library/library.py
@@ -488,6 +488,10 @@ class KernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
ImplicitTmaWarpSpecializedSm90 = enum_auto()
+ PtrArrayTmaWarpSpecializedCooperative = enum_auto()
+ PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
+ PtrArrayTmaWarpSpecializedPingpong = enum_auto()
+ PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
TmaWarpSpecialized1SmSm100 = enum_auto()
TmaWarpSpecialized2SmSm100 = enum_auto()
@@ -514,11 +518,6 @@ class KernelScheduleType(enum.Enum):
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
- KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
- KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
- KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
- KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
-
#
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
@@ -551,10 +550,10 @@ KernelScheduleTag = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
@@ -598,10 +597,10 @@ KernelScheduleSuffixes = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
- KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
@@ -667,8 +666,8 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm',
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm',
- EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative',
- EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong',
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
}
class EpilogueFunctor3x(enum.Enum):
@@ -686,6 +685,15 @@ def to_grouped_schedule(schedule, grouped):
return schedule
group_schedule_map = {
+ # SM90
+ KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+ KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
+ KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
+ EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+ EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+ EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
+ # SM100
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py
index 6e3038ec..79895305 100644
--- a/python/cutlass_library/sm90_utils.py
+++ b/python/cutlass_library/sm90_utils.py
@@ -494,8 +494,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
# the following cases are unsupported by grouped GEMM
if not is_aligned:
return [], []
- if not can_do_tma_epilogue:
- return [], []
if requires_transposed_epilogue:
return [], []
@@ -513,16 +511,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
return [], []
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
schedules = []
- if not grouped:
- schedules.append(
- [
- KernelScheduleType.TmaWarpSpecializedCooperative,
- EpilogueScheduleType.TmaWarpSpecializedCooperative
- ])
schedules.append(
[
- KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
- EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+ ])
+ schedules.append(
+ [
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
return schedules, []
return [], []
@@ -586,18 +583,9 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
return schedules, stream_k_schedules
- if grouped:
- pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
- cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum
- if can_do_tma_epilogue:
- schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong])
- if can_do_cooperative:
- schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative])
- return schedules, []
-
schedules = []
- # Pruning: emit Void-C kernels with persistent kernels only
- if level >= 1 or not is_void_c:
+ # Pruning: emit Void-C and Grouped kernels with persistent kernels only
+ if (level >= 1 or not is_void_c) and not grouped:
# Pruning: don't stamp out fp8 kernels with auto schedule
if not is_fp8:
schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
@@ -610,28 +598,29 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
if not is_fp8 or level >= 1:
schedules.append([
- KernelScheduleType.TmaWarpSpecializedPingpong,
- EpilogueScheduleType.TmaWarpSpecialized
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
])
if can_do_fp8_fast_accum:
schedules.append([
- KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
- EpilogueScheduleType.TmaWarpSpecialized
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped),
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
])
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
- # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
+ # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue
if not is_fp8 or level >= 1:
- schedules.append([KernelScheduleType.TmaWarpSpecializedPingpong, default_epilogue])
+ schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
if can_do_fp8_fast_accum:
- schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
- schedules.append([KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, default_epilogue])
+ if not grouped:
+ schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
+ schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
if can_do_cooperative:
schedules.append([
- KernelScheduleType.TmaWarpSpecializedCooperative,
- default_epilogue
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+ to_grouped_schedule(default_epilogue, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
@@ -639,8 +628,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
])
if can_do_fp8_fast_accum:
schedules.append([
- KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
- default_epilogue
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+ to_grouped_schedule(default_epilogue, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
@@ -652,8 +641,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
assert not requires_transposed_epilogue
if can_do_cooperative:
schedules.append([
- KernelScheduleType.TmaWarpSpecializedCooperative,
- EpilogueScheduleType.TmaWarpSpecializedCooperative
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
@@ -661,14 +650,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
])
if can_do_fp8_fast_accum:
schedules.append([
- KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
- EpilogueScheduleType.TmaWarpSpecializedCooperative
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
])
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
EpilogueScheduleType.TmaWarpSpecializedCooperative
])
-
+ # Grouped GEMM does not support the Stream-K scheduler
+ if grouped:
+ return schedules, []
return schedules, stream_k_schedules
diff --git a/tools/library/src/grouped_gemm_operation_3x.hpp b/tools/library/src/grouped_gemm_operation_3x.hpp
index c21c82c7..d4b1e26f 100644
--- a/tools/library/src/grouped_gemm_operation_3x.hpp
+++ b/tools/library/src/grouped_gemm_operation_3x.hpp
@@ -204,9 +204,6 @@ protected:
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
- // Single alpha and beta for all groups
- fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
- fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
return Status::kSuccess;
}
@@ -215,6 +212,8 @@ protected:
fusion_args.beta = 0;
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(arguments.alpha);
fusion_args.beta_ptr = static_cast<ElementCompute const*>(arguments.beta);
+ fusion_args.alpha_ptr_array = nullptr;
+ fusion_args.beta_ptr_array = nullptr;
return Status::kSuccess;
}
else {