diff --git a/.gitignore b/.gitignore
index acddb1f9..e7a02687 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 # PyCache files
 __pycache__/
-cutlass_library.egg-info/
\ No newline at end of file
+cutlass_library.egg-info/
+/build*
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 1ef06a36..843ed365 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -1,17 +1,18 @@
 ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS")
-[README](./README.md#documentation) > **Active Developers**
+[README](./README.md#documentation) > **Contributors**
 # CUTLASS Developers **
-Andrew Kerr (CUTLASS founding member)
+Andrew Kerr
+Paul Springer
Dustyn Blasig
Albert Xu
Junkai Wu
Xiuxia Zhang
-Haicheng Wu (CUTLASS founding member)
+Haicheng Wu
Jack Yang
-Pradeep Ramani (CUTLASS 3.x founding member)
+Pradeep Ramani
Aditya Atluri
Han Li
Nick Zhao
@@ -20,15 +21,15 @@ Yu-Jung Chen
Markus Hoehnerbach
Honghao Lu
Mihir Awatramani
-Hao Sheng
+Hao Sheng
Zekun Fan
-Aniket Shivam
+Aniket Shivam
Siyu Liu
Richard Cai
Vikas Gupta
Ethan Yan
-Vijay Thakkar (CUTLASS 3.x and CuTe founding member)
-Cris Cecka (CuTe and CUTLASS 3.x founding member)
+Vijay Thakkar
+Cris Cecka
Lawrence Ryan
Qun Song
Daniel Ricketts
@@ -69,5 +70,61 @@ Shreya Gaur
 ** _The list is sorted in order of the author's first contribution to the CUTLASS project._
+
+# CUTE Developers
+
+Cris Cecka
+Vijay Thakkar
+
+
+# CUTLASS Product Manager
+Matthew Nicely
+
+
+# Former CUTLASS Developers
+
+Manish Gupta
+Duane Merrill
+Piotr Majcher
+Naila Farooqui
+Mark Hoemmen
+Rawn Henry
+Jin Wang
+Timmy Liu
+Manikandan Ananth
+David Tanner
+
+
+# Acknowledgements
+
+Tri Dao
+Jay Shah
+Timothy Costa
+Julien Demouth
+Brian Fahs
+Michael Garland
+Michael Goldfarb
+Mostafa Hagog
+Fei Hu
+Alan Kaatz
+Tina Li
+Wei Liu
+Tim Martin
+Kevin Siu
+Markus Tavenrath
+John Tran
+Vicki Wang
+Fung Xie
+Yang Xu
+Scott Yokim
+Girish Bharambe
+Luke Durant
+Carter Edwards
+Olivier Giroux
+Stephen Jones
+Rishkul Kulkarni
+Bryce Lelbach
+Joel McCormack
+Kyrylo Perelygin
+Sean Treichler
diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 7736dbee..504abcc6 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -234,8 +234,6 @@ struct CollectiveBuilder< KernelPtrArrayTmaWarpSpecializedCooperative, KernelPtrArrayTmaWarpSpecializedPingpong>); static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), - "KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now."); // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; @@ -267,12 +265,17 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + /* For FP8 use a separate mainloop compared to other datatypes */ using DispatchPolicy = cute::conditional_t, - /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t, + MainloopSm90ArrayTmaGmmaWarpSpecialized + >, cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecialized>>; + MainloopSm90TmaGmmaWarpSpecialized + > + >; using SmemCopyAtomA = void; using SmemCopyAtomB = void; diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index f6a0dcb3..e33c06a7 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -46,8 +46,9 @@ #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp" #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" -#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" #if !defined(__CUDACC_RTC__) diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 00000000..b3d857ff --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,768 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/tensor.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecializedFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = 
cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + }; + + // 
Device side kernel params + struct Params { + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t mma_promotion_interval = 4; + void* tensormaps; + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + ElementA const* ptr_A_first_batch = reinterpret_cast(args.ptr_A); + ElementB const* ptr_B_first_batch = reinterpret_cast(args.ptr_B); + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + void* tensormaps = workspace; + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + args.mma_promotion_interval, + tensormaps, + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + if 
(problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class TensorA, class TensorB, + class TensorMapA, class TensorMapB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, 
block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + 
static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit 
flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + accumulation.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + 
// Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + ElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + ElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 8747f48b..0936eb25 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -336,6 +336,21 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized { "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies"); }; +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm +// For FP8 kernels +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative +> +struct MainloopSm90ArrayTmaGmmaWarpSpecializedFP8 + : MainloopSm90ArrayTmaGmmaWarpSpecialized { + static_assert( + cute::is_base_of_v || + cute::is_base_of_v, + "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies"); +}; + // n-buffer in smem (Hopper TMA), pipelined with Hopper sparse GMMA and TMA, Warp specialized dynamic schedule template< int Stages_, diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index bc2cc7b1..d5f02db0 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -488,6 +488,10 @@ class KernelScheduleType(enum.Enum): TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() ImplicitTmaWarpSpecializedSm90 = enum_auto() + PtrArrayTmaWarpSpecializedCooperative = enum_auto() + PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + PtrArrayTmaWarpSpecializedPingpong = enum_auto() + PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() TmaWarpSpecialized1SmSm100 = enum_auto() TmaWarpSpecialized2SmSm100 = enum_auto() @@ -514,11 +518,6 @@ class KernelScheduleType(enum.Enum): Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() - KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto() - KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() - KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto() - KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() - # KernelScheduleTag = { KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', @@ -551,10 +550,10 @@ KernelScheduleTag = { KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', - KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', - 
KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', - KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', - KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", @@ -598,10 +597,10 @@ KernelScheduleSuffixes = { KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', - KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', - KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', - KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', - KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', @@ -667,8 +666,8 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm', EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm', - EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative', - EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma', } class EpilogueFunctor3x(enum.Enum): @@ -686,6 +685,15 @@ def to_grouped_schedule(schedule, grouped): return schedule group_schedule_map = { + # SM90 + KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, + KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : 
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, + EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized, + # SM100 KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100, KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100, diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 6e3038ec..79895305 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -494,8 +494,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, # the following cases are unsupported by grouped GEMM if not is_aligned: return [], [] - if not can_do_tma_epilogue: - return [], [] if requires_transposed_epilogue: return [], [] @@ -513,16 +511,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, return [], [] if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue: schedules = [] - if not grouped: - schedules.append( - [ - KernelScheduleType.TmaWarpSpecializedCooperative, - EpilogueScheduleType.TmaWarpSpecializedCooperative - ]) schedules.append( [ - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) ]) return schedules, [] return [], [] @@ -586,18 +583,9 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, return schedules, stream_k_schedules - if grouped: - pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum - cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum - if can_do_tma_epilogue: - schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong]) - if can_do_cooperative: - schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative]) - return schedules, [] - schedules = [] - # Pruning: emit Void-C kernels with persistent kernels only - if level >= 1 or not is_void_c: + # Pruning: emit Void-C and Grouped kernels with persistent kernels only + if (level >= 1 or not is_void_c) and not grouped: # Pruning: don't stamp out fp8 kernels with auto schedule if not is_fp8: schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue]) @@ -610,28 +598,29 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, # 
Inconsistency: fp8 pingpong only gets stamped out with fast accum if not is_fp8 or level >= 1: schedules.append([ - KernelScheduleType.TmaWarpSpecializedPingpong, - EpilogueScheduleType.TmaWarpSpecialized + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) ]) if can_do_fp8_fast_accum: schedules.append([ - KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecialized + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) ]) if CudaToolkitVersionSatisfies(cuda_version, 12, 1): - # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue + # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue if not is_fp8 or level >= 1: - schedules.append([KernelScheduleType.TmaWarpSpecializedPingpong, default_epilogue]) + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) if can_do_fp8_fast_accum: - schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue]) - schedules.append([KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, default_epilogue]) + if not grouped: + schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue]) + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)]) if can_do_cooperative: schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperative, - default_epilogue + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(default_epilogue, grouped) ]) stream_k_schedules.append([ KernelScheduleType.TmaWarpSpecializedCooperative, @@ -639,8 +628,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, ]) if can_do_fp8_fast_accum: schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, - default_epilogue + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(default_epilogue, grouped) ]) stream_k_schedules.append([ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, @@ -652,8 +641,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, assert not requires_transposed_epilogue if can_do_cooperative: schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperative, - EpilogueScheduleType.TmaWarpSpecializedCooperative + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) ]) stream_k_schedules.append([ KernelScheduleType.TmaWarpSpecializedCooperative, @@ -661,14 +650,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, ]) if can_do_fp8_fast_accum: schedules.append([ - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecializedCooperative + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) ]) stream_k_schedules.append([ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative ]) - + # Grouped GEMM do 
not support Stream-K scheduler + if grouped: + return schedules, [] return schedules, stream_k_schedules diff --git a/tools/library/src/grouped_gemm_operation_3x.hpp b/tools/library/src/grouped_gemm_operation_3x.hpp index c21c82c7..d4b1e26f 100644 --- a/tools/library/src/grouped_gemm_operation_3x.hpp +++ b/tools/library/src/grouped_gemm_operation_3x.hpp @@ -204,9 +204,6 @@ protected: fusion_args.beta_ptr = nullptr; fusion_args.alpha_ptr_array = nullptr; fusion_args.beta_ptr_array = nullptr; - // Single alpha and beta for all groups - fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; - fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; return Status::kSuccess; } @@ -215,6 +212,8 @@ protected: fusion_args.beta = 0; fusion_args.alpha_ptr = static_cast(arguments.alpha); fusion_args.beta_ptr = static_cast(arguments.beta); + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; return Status::kSuccess; } else {
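
The collective changes above are easiest to read against a concrete instantiation. The sketch below is not part of this diff; it shows how the new `MainloopSm90ArrayTmaGmmaWarpSpecializedFP8` path would be reached through the standard SM90 `CollectiveBuilder` once the FP8 `static_assert` is removed: an FP8 grouped GEMM built with the plain (non-FastAccum) `KernelPtrArrayTmaWarpSpecializedCooperative` schedule now dispatches to the ptr-array FP8 mainloop with periodic accumulator promotion. Element types, tile/cluster shapes, alignments, and the epilogue schedule are illustrative assumptions modeled on the existing Hopper grouped-GEMM examples, not values taken from this change.

```cpp
// Hedged sketch only: builder-level instantiation assumed to exercise the new FP8
// ptr-array/grouped mainloop added in this change. All concrete choices below
// (tile sizes, alignments, epilogue schedule) are example values.
#include "cutlass/cutlass.h"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"

using ElementA           = cutlass::float_e4m3_t;   // FP8 operands
using ElementB           = cutlass::float_e4m3_t;
using ElementC           = cutlass::half_t;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;

using TileShape    = cute::Shape<cute::_256, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;

// Grouped GEMM: per-group problem shapes; layouts/strides are passed as pointers.
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC*, 8,
    ElementC, LayoutC*, 8,
    cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative
  >::CollectiveOp;

// Non-FastAccum ptr-array schedule: with this change the builder selects
// MainloopSm90ArrayTmaGmmaWarpSpecializedFP8 for FP8 inputs instead of
// rejecting them via the removed static_assert.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA*, 16,
    ElementB, LayoutB*, 16,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative
  >::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```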