Hopper Grouped GEMM support for FP8 Accum (#2123)

* Add support for FP8 accum, with profiler extension
* Update .gitignore
* Update contributors

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>

.gitignore (vendored)
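For context, a minimal sketch of how the new capability is reached from user code: a Hopper grouped GEMM mainloop built through the CollectiveBuilder with FP8 operands and the plain (non-FastAccum) Ptr-Array/Grouped cooperative schedule, which this commit now accepts. The tile shape, cluster shape, layouts, and alignments below are illustrative assumptions, not values taken from this commit.

// Illustrative sketch only: builds an SM90 grouped-GEMM mainloop for FP8 inputs using the
// Ptr-Array/Grouped cooperative schedule. Shapes, layouts, and alignments are assumptions
// for the example, not values prescribed by this commit.
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

using ElementA = cutlass::float_e4m3_t;   // FP8 operand A
using ElementB = cutlass::float_e4m3_t;   // FP8 operand B
using ElementAccumulator = float;         // accumulate in FP32

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;

using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA*, 16,               // pointer-to-layout marks grouped/ptr-array operands
    ElementB, LayoutB*, 16,
    ElementAccumulator,
    cute::Shape<cute::_128, cute::_128, cute::_128>,   // CTA tile (M, N, K)
    cute::Shape<cute::_1, cute::_2, cute::_1>,         // cluster shape
    cutlass::gemm::collective::StageCountAuto,
    // Plain (non-FastAccum) grouped schedule: with this commit it dispatches to the new
    // FP8 mainloop that periodically promotes partial sums (see the new collective below).
    cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative
  >::CollectiveOp;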
@@ -1,3 +1,4 @@
 # PyCache files
 __pycache__/
 cutlass_library.egg-info/
+/build*
@@ -1,17 +1,18 @@
 
 
-[README](./README.md#documentation) > **Active Developers**
+[README](./README.md#documentation) > **Contributors**
 
 # CUTLASS Developers **
 
-Andrew Kerr (CUTLASS founding member)<br />
+Andrew Kerr<br />
+Paul Springer<br />
 Dustyn Blasig<br />
 Albert Xu<br />
 Junkai Wu<br />
 Xiuxia Zhang<br />
-Haicheng Wu (CUTLASS founding member)<br />
+Haicheng Wu<br />
 Jack Yang<br />
-Pradeep Ramani (CUTLASS 3.x founding member)<br />
+Pradeep Ramani<br />
 Aditya Atluri<br />
 Han Li<br />
 Nick Zhao<br />
@@ -20,15 +21,15 @@ Yu-Jung Chen<br />
 Markus Hoehnerbach<br />
 Honghao Lu<br />
 Mihir Awatramani<br />
 Hao Sheng<br />
 Zekun Fan<br />
 Aniket Shivam<br />
 Siyu Liu<br />
 Richard Cai<br />
 Vikas Gupta<br />
 Ethan Yan<br />
-Vijay Thakkar (CUTLASS 3.x and CuTe founding member)<br />
-Cris Cecka (CuTe and CUTLASS 3.x founding member)<br />
+Vijay Thakkar<br />
+Cris Cecka<br />
 Lawrence Ryan<br />
 Qun Song<br />
 Daniel Ricketts<br />
@@ -69,5 +70,61 @@ Shreya Gaur<br />
 
 ** _The list is sorted in order of the author's first contribution to the CUTLASS project._
 
+
+# CUTE Developers
+
+Cris Cecka<br />
+Vijay Thakkar<br />
+
+
 # CUTLASS Product Manager
+
 Matthew Nicely<br />
+
+
+# Former CUTLASS Developers
+
+Manish Gupta<br />
+Duane Merrill<br />
+Piotr Majcher<br />
+Naila Farooqui<br />
+Mark Hoemmen<br />
+Rawn Henry<br />
+Jin Wang<br />
+Timmy Liu<br />
+Manikandan Ananth<br />
+David Tanner<br />
+
+
+# Acknowledgements
+
+Tri Dao<br />
+Jay Shah<br />
+Timothy Costa<br />
+Julien Demouth<br />
+Brian Fahs<br />
+Michael Garland<br />
+Michael Goldfarb<br />
+Mostafa Hagog<br />
+Fei Hu<br />
+Alan Kaatz<br />
+Tina Li<br />
+Wei Liu<br />
+Tim Martin<br />
+Kevin Siu<br />
+Markus Tavenrath<br />
+John Tran<br />
+Vicki Wang<br />
+Fung Xie<br />
+Yang Xu<br />
+Scott Yokim<br />
+Girish Bharambe<br />
+Luke Durant<br />
+Carter Edwards<br />
+Olivier Giroux<br />
+Stephen Jones<br />
+Rishkul Kulkarni<br />
+Bryce Lelbach<br />
+Joel McCormack<br />
+Kyrylo Perelygin<br />
+Sean Treichler<br />
@@ -234,8 +234,6 @@ struct CollectiveBuilder<
                                                  KernelPtrArrayTmaWarpSpecializedCooperative,
                                                  KernelPtrArrayTmaWarpSpecializedPingpong>);
   static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
-  static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
-                "KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now.");
 
   // For fp32 types, map to tf32 MMA value type
   using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@@ -267,12 +265,17 @@ struct CollectiveBuilder<
 
   static constexpr int PipelineStages = detail::compute_stage_count_or_override<Sm90ReducedSmemCapacityBytes,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
+  /* For FP8 use a separate mainloop compared to other datatypes */
   using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
-      MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
-      /* For FP8 use a separate mainloop compared to other datatypes */
+      cute::conditional_t<IsFP8Input,
+        MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
+        MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
+      >,
       cute::conditional_t<IsFP8Input,
         MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
-        MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
+        MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>
+      >
+    >;
 
   using SmemCopyAtomA = void;
   using SmemCopyAtomB = void;
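The builder change above selects among four mainloops with two nested cute::conditional_t levels: the outer level keys on grouped/ptr-array vs. single GEMM, the inner level on FP8 vs. non-FP8 input. A standalone sketch of the same selection shape, using std::conditional_t and placeholder tag types rather than the real CUTLASS dispatch policies:

// Standalone illustration of the nested selection above, using placeholder tag types
// (the real CUTLASS dispatch policies are not needed to see the shape of the logic).
#include <type_traits>

struct ArrayFP8 {}; struct ArrayPlain {}; struct SingleFP8 {}; struct SinglePlain {};

template <bool IsArray, bool IsFP8>
using SelectedMainloop =
    std::conditional_t<IsArray,
        std::conditional_t<IsFP8, ArrayFP8, ArrayPlain>,
        std::conditional_t<IsFP8, SingleFP8, SinglePlain>>;

static_assert(std::is_same_v<SelectedMainloop<true,  true >, ArrayFP8>);
static_assert(std::is_same_v<SelectedMainloop<true,  false>, ArrayPlain>);
static_assert(std::is_same_v<SelectedMainloop<false, true >, SingleFP8>);
static_assert(std::is_same_v<SelectedMainloop<false, false>, SinglePlain>);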
@@ -48,6 +48,7 @@
 #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
 #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp"
 #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
+#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp"
 #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
 
 #if !defined(__CUDACC_RTC__)
@@ -0,0 +1,768 @@ (new file)
/***************************************************************************************************
 * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
#include "cutlass/numeric_types.h"

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/tensor.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {
using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

// WarpSpecialized Mainloop
template <
  int Stages,
  class ClusterShape,
  class KernelSchedule,
  class TileShape_,
  class ElementA_,
  class StrideA_,
  class ElementB_,
  class StrideB_,
  class TiledMma_,
  class GmemTiledCopyA_,
  class SmemLayoutAtomA_,
  class SmemCopyAtomA_,
  class TransformA_,
  class GmemTiledCopyB_,
  class SmemLayoutAtomB_,
  class SmemCopyAtomB_,
  class TransformB_>
struct CollectiveMma<
    MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>,
    TileShape_,
    ElementA_,
    StrideA_,
    ElementB_,
    StrideB_,
    TiledMma_,
    GmemTiledCopyA_,
    SmemLayoutAtomA_,
    SmemCopyAtomA_,
    TransformA_,
    GmemTiledCopyB_,
    SmemLayoutAtomB_,
    SmemCopyAtomB_,
    TransformB_>
{
  //
  // Type Aliases
  //
  using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
  using TileShape = TileShape_;
  using ElementA = ElementA_;
  using StrideA = StrideA_;
  using InternalStrideA = cute::remove_pointer_t<StrideA>;
  using ElementB = ElementB_;
  using StrideB = StrideB_;
  using InternalStrideB = cute::remove_pointer_t<StrideB>;
  using TiledMma = TiledMma_;
  using ElementAccumulator = typename TiledMma::ValTypeC;
  using GmemTiledCopyA = GmemTiledCopyA_;
  using GmemTiledCopyB = GmemTiledCopyB_;
  using SmemLayoutAtomA = SmemLayoutAtomA_;
  using SmemLayoutAtomB = SmemLayoutAtomB_;
  using SmemCopyAtomA = SmemCopyAtomA_;
  using SmemCopyAtomB = SmemCopyAtomB_;
  using TransformA = TransformA_;
  using TransformB = TransformB_;
  using ArchTag = typename DispatchPolicy::ArchTag;

  using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
  using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
  using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;

  using PipelineParams = typename MainloopPipeline::Params;

  // One threads per CTA are producers (1 for operand tile)
  static constexpr int NumProducerThreadEvents = 1;

  static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  // Tile along modes in a way that maximizes the TMA box size.
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{},
      make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

  static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
  static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
                cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
                "MMA atom must source both A and B operand from smem_desc for this mainloop.");
  static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");

  // Assumption: StrideA is congruent with Problem_MK
  using TMA_A = decltype(make_tma_copy(
      GmemTiledCopyA{},
      make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
      SmemLayoutA{}(_,_,cute::Int<0>{}),
      make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
      size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
  // Assumption: StrideB is congruent with Problem_NK
  using TMA_B = decltype(make_tma_copy(
      GmemTiledCopyB{},
      make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
      SmemLayoutB{}(_,_,cute::Int<0>{}),
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
      size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any

  struct SharedStorage {
    struct TensorStorage : cute::aligned_struct<128, _0> {
      cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
      cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
    } tensors;

    struct TensorMapStorage : cute::aligned_struct<128, _0> {
      cute::TmaDescriptor smem_tensormap_A;
      cute::TmaDescriptor smem_tensormap_B;
    } tensormaps;

    using PipelineStorage = typename MainloopPipeline::SharedStorage;
    PipelineStorage pipeline;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;
  using TensorMapStorage = typename SharedStorage::TensorMapStorage;
  using PipelineStorage = typename SharedStorage::PipelineStorage;

  static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;

  // Host side kernel arguments
  struct Arguments {
    ElementA const** ptr_A;
    StrideA dA;
    ElementB const** ptr_B;
    StrideB dB;
    uint32_t mma_promotion_interval = 4;
  };

  // Device side kernel params
  struct Params {
    TMA_A tma_load_a;
    TMA_B tma_load_b;
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
    uint32_t mma_promotion_interval = 4;
    void* tensormaps;
    ElementA const** ptr_A;
    StrideA dA;
    ElementB const** ptr_B;
    StrideB dB;
  };

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(
      ProblemShape problem_shapes,
      Arguments const& args,
      void* workspace) {
    // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
    // These will be replaced with correct values before the initial tma load.
    auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
    auto init_M = get<0>(init_shape);
    auto init_N = get<1>(init_shape);
    auto init_K = get<2>(init_shape);
    auto init_L = get<3>(init_shape);

    ElementA const* ptr_A_first_batch = reinterpret_cast<ElementA const*>(args.ptr_A);
    ElementB const* ptr_B_first_batch = reinterpret_cast<ElementB const*>(args.ptr_B);

    InternalStrideA stride_a;
    InternalStrideB stride_b;
    if constexpr (IsGroupedGemmKernel) {
      // Strides for Grouped Gemm will be replaced prior to the first access regardless.
      stride_a = InternalStrideA{};
      stride_b = InternalStrideB{};
    }
    else {
      // Tensor shapes for Ptr-Array are initialized correctly only here.
      auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
      init_M = get<0>(problem_shape_MNK);
      init_N = get<1>(problem_shape_MNK);
      init_K = get<2>(problem_shape_MNK);

      stride_a = args.dA;
      stride_b = args.dB;
    }
    Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a));
    Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b));
    TMA_A tma_load_a = make_tma_copy(
        GmemTiledCopyA{},
        tensor_a,
        SmemLayoutA{}(_,_,cute::Int<0>{}),
        make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
        size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
    TMA_B tma_load_b = make_tma_copy(
        GmemTiledCopyB{},
        tensor_b,
        SmemLayoutB{}(_,_,cute::Int<0>{}),
        make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
        size<0>(ClusterShape{})); // mcast along M mode for this N load, if any

    void* tensormaps = workspace;

    return {
      tma_load_a,
      tma_load_b,
      TmaTransactionBytes,
      args.mma_promotion_interval,
      tensormaps,
      reinterpret_cast<ElementA const**>(args.ptr_A),
      args.dA,
      reinterpret_cast<ElementB const**>(args.ptr_B),
      args.dB
    };
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
    constexpr uint32_t NumInputTensors = 2;
    constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
    // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies
    return (NumInputTensors * SizeOfCuTensorMap * sm_count);
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  template<class ProblemShape>
  static bool
  can_implement(
      ProblemShape problem_shapes,
      Arguments const& args) {
    constexpr int tma_alignment_bits = 128;
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
    constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;

    bool implementable = true;
    if (problem_shapes.is_host_problem_shape_available()) {
      // Check alignment for all problem sizes
      for (int i = 0; i < problem_shapes.groups(); i++) {
        auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
        auto [M,N,K,L] = problem_shape_MNKL;
        implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), InternalStrideA{});
        implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), InternalStrideB{});
      }
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    }
    return implementable;
  }

  static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
  static constexpr int K_PIPE_MMAS = 1;
  static constexpr uint32_t TmaTransactionBytes =
      cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value))+
      cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));

  // Set up the data needed by this collective for load and mma.
  // Returns a tuple of tensors. The collective and the kernel layer have the contract that the
  // returned tuple must contain at least two elements, with the first two elements being:
  // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
  // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
  // The rest of the tensors can be specified as needed by this collective.
  template <class ProblemShape_MNKL>
  CUTLASS_DEVICE auto
  load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
    using X = Underscore;
    // Separate out problem shape for convenience
    auto [M,N,K,L] = problem_shape_MNKL;
    const int32_t mock_L = 1;

    // TMA requires special handling of strides to deal with coord codomain mapping
    // Represent the full tensors -- get these from TMA
    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l)
    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l)

    // Make tiled views, defer the slice
    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
    Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)

    return cute::make_tuple(gA_mkl, gB_nkl);
  }

  // Perform a collective-scoped matrix multiply-accumulate
  // Producer Perspective
  template <
    class TensorA, class TensorB,
    class TensorMapA, class TensorMapB,
    class KTileIterator, class BlockCoord
  >
  CUTLASS_DEVICE void
  load(
      Params const& mainloop_params,
      MainloopPipeline pipeline,
      PipelineState smem_pipe_write,
      cute::tuple<TensorA, TensorB> const& load_inputs,
      cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
      BlockCoord const& blk_coord,
      KTileIterator k_tile_iter, int k_tile_count,
      int thread_idx,
      uint32_t block_rank_in_cluster,
      TensorStorage& shared_tensors) {
    int lane_predicate = cute::elect_one_sync();

    if (lane_predicate) {
      Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
      Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)

      //
      // Prepare the TMA loads for A and B
      //

      constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
      uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

      Tensor gA_mkl = get<0>(load_inputs);
      Tensor gB_nkl = get<1>(load_inputs);

      auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
      auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

      // Partition the inputs based on the current block coordinates.
      auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
      Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
      Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)

      // Applies the mapping from block_tma_a
      Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
      Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)

      Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
      Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)

      uint16_t mcast_mask_a = 0;
      uint16_t mcast_mask_b = 0;

      // Issue TmaLoads
      // Maps the tile -> block, value
      if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
        for (int n = 0; n < size<1>(block_layout); ++n) {
          mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
        }
      }

      if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
        for (int m = 0; m < size<0>(block_layout); ++m) {
          mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
        }
      }

      // Mainloop
      CUTLASS_PRAGMA_NO_UNROLL
      for ( ; k_tile_count > 0; --k_tile_count) {
        // LOCK smem_pipe_write for _writing_
        pipeline.producer_acquire(smem_pipe_write);

        //
        // Copy gmem to smem for *k_tile_iter
        //

        using BarrierType = typename MainloopPipeline::ProducerBarrierType;
        BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

        int write_stage = smem_pipe_write.index();
        copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
        copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
        ++k_tile_iter;

        // Advance smem_pipe_write
        ++smem_pipe_write;
      }
    }
  }

  /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  CUTLASS_DEVICE void
  load_tail(
      MainloopPipeline pipeline,
      PipelineState smem_pipe_write) {
    int lane_predicate = cute::elect_one_sync();

    // Issue the epilogue waits
    if (lane_predicate) {
      /* This helps avoid early exit of blocks in Cluster
       * Waits for all stages to either be released (all
       * Consumer UNLOCKs), or if the stage was never used
       * then would just be acquired since the phase was
       * still inverted from make_producer_start_state
       */
      pipeline.producer_tail(smem_pipe_write);
    }
  }

  /// Perform a collective-scoped matrix multiply-accumulate
  /// Consumer Perspective
  template <
    class FrgTensorC
  >
  CUTLASS_DEVICE void
  mma(MainloopPipeline pipeline,
      PipelineState smem_pipe_read,
      FrgTensorC& accum,
      int k_tile_count,
      int thread_idx,
      TensorStorage& shared_tensors,
      Params const& mainloop_params) {

    static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
    static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
    static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
    static_assert(cute::is_void_v<SmemCopyAtomA>,
      "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
    static_assert(cute::is_void_v<SmemCopyAtomB>,
      "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");

    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)

    //
    // Define C accumulators and A/B partitioning
    //

    // Layout of warp group to thread mapping

    static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
                  stride<0>(typename TiledMma::BLayout{}) == 0 and
                  size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
                  size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
                  "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");

    constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
    Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
                                                  Int<NumThreadsPerWarpGroup>{});

    int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

    TiledMma tiled_mma;
    auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));

    Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)

    // Allocate "fragments/descriptors"
    Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)

    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum));  // M
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum));  // N
    CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));   // K
    CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB));   // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE

    //
    // PIPELINED MAIN LOOP
    //
    static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
        "ERROR : Incorrect number of MMAs in flight");

    // We release buffers to producer warps(dma load) with some mmas in flight
    PipelineState smem_pipe_release = smem_pipe_read;

    // Prologue GMMAs
    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

    GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA));
    warpgroup_fence_operand(accumulation());
    CUTLASS_PRAGMA_UNROLL
    for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
    {
      // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
      auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
      pipeline.consumer_wait(smem_pipe_read, barrier_token);

      if (accumulation.prepare_if_needed()) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
      }

      int read_stage = smem_pipe_read.index();
      warpgroup_arrive();
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // (V,M,K) x (V,N,K) => (V,M,N)
        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      }
      warpgroup_commit_batch();

      accumulation.promote_if_needed();

      ++smem_pipe_read;
    }

    warpgroup_fence_operand(accumulation());
    // Mainloop GMMAs
    k_tile_count -= prologue_mma_count;

    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count)
    {
      // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
      auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
      pipeline.consumer_wait(smem_pipe_read, barrier_token);

      //
      // Compute on k_tile
      //

      int read_stage = smem_pipe_read.index();

      if (accumulation.prepare_if_needed()) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
      }

      warpgroup_fence_operand(accumulation());
      warpgroup_arrive();
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // (V,M,K) x (V,N,K) => (V,M,N)
        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      }
      warpgroup_commit_batch();

      /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
      warpgroup_wait<K_PIPE_MMAS>();
      warpgroup_fence_operand(accumulation());

      accumulation.promote_if_needed();

      pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it

      // Advance smem_pipe_read and smem_pipe_release
      ++smem_pipe_read;
      ++smem_pipe_release;
    }

    accumulation.promote_residue_if_needed();

    warpgroup_fence_operand(accumulation());
  }

  /// Perform a Consumer Epilogue to release all buffers
  CUTLASS_DEVICE void
  mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
    // Prologue GMMAs
    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
    k_tile_count -= prologue_mma_count;

    smem_pipe_release.advance(k_tile_count);

    // Wait on all GMMAs to complete
    warpgroup_wait<0>();

    for (int count = 0; count < prologue_mma_count; ++count) {
      pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
      ++smem_pipe_release;
    }
  }

  //
  // Methods to perform different parts of TMA/Tensormap modifications
  //

  CUTLASS_DEVICE auto
  tensormaps_init(
      Params const& mainloop_params,
      TensorMapStorage& shared_tensormaps,
      int32_t sm_count,
      int32_t sm_idx) {
    cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);

    cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
    cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];

    if (cute::elect_one_sync()) {
      // Bringing tensormaps from params to smem for modification later
      Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
      Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
      Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
      Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});

      copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
      copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));
    }
    __syncwarp();

    return cute::make_tuple(tma_desc_a, tma_desc_b);
  }

  // Replace address for the global tensor (to be done by single thread)
  CUTLASS_DEVICE
  void
  tensormaps_replace_global_address(
      TensorMapStorage& shared_tensormaps,
      Params const& mainloop_params,
      int32_t next_batch) {
    // Replacing global_address for the next batch
    cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A,
                                                    mainloop_params.ptr_A[next_batch]);
    cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B,
                                                    mainloop_params.ptr_B[next_batch]);
  }

  // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
  template <class ProblemShape_MNKL>
  CUTLASS_DEVICE
  void
  tensormaps_replace_global_tensor_properties(
      TensorMapStorage& shared_tensormaps,
      Params const& mainloop_params,
      int32_t next_group,
      ProblemShape_MNKL problem_shape_mnkl) {
    const uint32_t M = get<0>(problem_shape_mnkl);
    const uint32_t N = get<1>(problem_shape_mnkl);
    const uint32_t K = get<2>(problem_shape_mnkl);
    // Replace all dims for consistency
    constexpr int MaxTensorRank = 5;
    cute::array<uint32_t, MaxTensorRank> prob_shape_A  = {1,1,1,1,1};
    cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0,0,0,0,0};
    cute::array<uint32_t, MaxTensorRank> prob_shape_B  = {1,1,1,1,1};
    cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0,0,0,0,0};

    ElementA const* ptr_A = nullptr;
    Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);

    ElementB const* ptr_B = nullptr;
    Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]);

    cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a,
                                             prob_shape_A, prob_stride_A);
    cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
                                             prob_shape_B, prob_stride_B);

    // Convert strides to byte strides
    for (uint64_t& stride : prob_stride_A) {
      stride = (stride * sizeof_bits_v<ElementA>) / 8;
    }
    for (uint64_t& stride : prob_stride_B) {
      stride = (stride * sizeof_bits_v<ElementB>) / 8;
    }

    cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A,
                                                            prob_shape_A,
                                                            prob_stride_A);
    cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B,
                                                            prob_shape_B,
                                                            prob_stride_B);
  }

  template <class TensorMapA, class TensorMapB, class ProblemShape_MNKL>
  CUTLASS_DEVICE
  void
  tensormaps_perform_update(
      TensorMapStorage& shared_tensormaps,
      Params const& mainloop_params,
      cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
      ProblemShape_MNKL problem_shape_mnkl,
      int32_t next_batch) {
    if (cute::elect_one_sync()) {
      // Replacing global_address for the next batch
      tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);

      if constexpr (IsGroupedGemmKernel) {
        // Replacing global dims and strides for the next batch
        tensormaps_replace_global_tensor_properties(shared_tensormaps,
          mainloop_params, next_batch, problem_shape_mnkl);
      }
    }
  }

  template <class TensorMapA, class TensorMapB>
  CUTLASS_DEVICE
  void
  tensormaps_cp_fence_release (
      TensorMapStorage& shared_tensormaps,
      cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
    // Entire warp must do this (i.e. it's aligned)
    tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A);
    tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B);
  }

  // The entire warp must call this function collectively (that is, the instructions are aligned)
  template <class TensorMapA, class TensorMapB>
  CUTLASS_DEVICE
  void
  tensormaps_fence_acquire(cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
    cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
    cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
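The new collective above follows the existing ptr-array warp-specialized mainloop closely; the FP8-specific part is that the accumulator fragment is wrapped in GmmaFP8Accumulation and partial sums are promoted every mma_promotion_interval k-iterations (promote_if_needed / promote_residue_if_needed) instead of accumulating indefinitely in the MMA register tile. A scalar analogy of that periodic-promotion idea, offered as a hedged illustration rather than the CUTLASS implementation:

// Scalar analogy of periodic promotion (illustration only, not the CUTLASS implementation):
// partial sums accumulate in a short-lived "temp" accumulator and are folded into the main
// accumulator every `promotion_interval` steps, mirroring what GmmaFP8Accumulation does for
// the GMMA register tile.
#include <cstdio>

float dot_with_promotion(const float* a, const float* b, int k, int promotion_interval) {
  float accum = 0.0f;      // long-lived accumulator (promoted results land here)
  float temp  = 0.0f;      // short-lived accumulator, reset after each promotion
  for (int i = 0; i < k; ++i) {
    temp += a[i] * b[i];                       // analogous to the in-flight GMMA accumulation
    if ((i + 1) % promotion_interval == 0) {   // analogous to promote_if_needed()
      accum += temp;
      temp = 0.0f;
    }
  }
  accum += temp;                               // analogous to promote_residue_if_needed()
  return accum;
}

int main() {
  float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float b[8] = {1, 1, 1, 1, 1, 1, 1, 1};
  std::printf("%f\n", dot_with_promotion(a, b, 8, 4)); // prints 36.000000
}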
@ -336,6 +336,21 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
|
|||||||
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
|
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
|
||||||
|
// For FP8 kernels
|
||||||
|
template<
|
||||||
|
int Stages_,
|
||||||
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
||||||
|
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
|
||||||
|
>
|
||||||
|
struct MainloopSm90ArrayTmaGmmaWarpSpecializedFP8
|
||||||
|
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
||||||
|
static_assert(
|
||||||
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
|
||||||
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>,
|
||||||
|
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
|
||||||
|
};
|
||||||
|
|
||||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper sparse GMMA and TMA, Warp specialized dynamic schedule
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper sparse GMMA and TMA, Warp specialized dynamic schedule
|
||||||
template<
|
template<
|
||||||
int Stages_,
|
int Stages_,
|
||||||
|
@@ -488,6 +488,10 @@ class KernelScheduleType(enum.Enum):
   TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
   TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
   ImplicitTmaWarpSpecializedSm90 = enum_auto()
+  PtrArrayTmaWarpSpecializedCooperative = enum_auto()
+  PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
+  PtrArrayTmaWarpSpecializedPingpong = enum_auto()
+  PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
 
   TmaWarpSpecialized1SmSm100 = enum_auto()
   TmaWarpSpecialized2SmSm100 = enum_auto()
@@ -514,11 +518,6 @@ class KernelScheduleType(enum.Enum):
   Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
   Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
 
-  KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto()
-  KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
-  KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto()
-  KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
-
 #
 KernelScheduleTag = {
   KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
@@ -551,10 +550,10 @@ KernelScheduleTag = {
   KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
   KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
 
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
 
   KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
   KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
@@ -598,10 +597,10 @@ KernelScheduleSuffixes = {
   KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
   KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
 
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
-  KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
+  KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
 
   KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
   KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
@@ -667,8 +666,8 @@ EpilogueScheduleSuffixes = {
   EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
   EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm',
   EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm',
-  EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative',
-  EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong',
+  EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
+  EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
 }
 
 class EpilogueFunctor3x(enum.Enum):
@@ -686,6 +685,15 @@ def to_grouped_schedule(schedule, grouped):
   return schedule
 
 group_schedule_map = {
+  # SM90
+  KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+  KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+  KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
+  KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
+  EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+  EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+  EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
+  # SM100
   KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
   KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
   KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
|||||||
@ -494,8 +494,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
|||||||
# the following cases are unsupported by grouped GEMM
|
# the following cases are unsupported by grouped GEMM
|
||||||
if not is_aligned:
|
if not is_aligned:
|
||||||
return [], []
|
return [], []
|
||||||
if not can_do_tma_epilogue:
|
|
||||||
return [], []
|
|
||||||
if requires_transposed_epilogue:
|
if requires_transposed_epilogue:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
@ -513,16 +511,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
|||||||
return [], []
|
return [], []
|
||||||
   if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
     schedules = []
-    if not grouped:
-      schedules.append(
-        [
-          KernelScheduleType.TmaWarpSpecializedCooperative,
-          EpilogueScheduleType.TmaWarpSpecializedCooperative
-        ])
     schedules.append(
       [
-        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
-        EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
+      ])
+    schedules.append(
+      [
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
       ])
     return schedules, []
   return [], []
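The branch above is gated on CudaToolkitVersionSatisfies(cuda_version, 12, 1). As a rough, hypothetical illustration of what such a gate checks (not the cutlass_library implementation), a CUDA toolkit version string can be compared component-wise:

def cuda_toolkit_version_satisfies(cuda_version, major, minor):
  # Hypothetical helper for illustration only: treat cuda_version as a
  # "major.minor[.patch]" string and check that it is at least major.minor.
  parts = [int(x) for x in str(cuda_version).split('.')[:2]]
  while len(parts) < 2:
    parts.append(0)
  return tuple(parts) >= (major, minor)

assert cuda_toolkit_version_satisfies("12.4", 12, 1)
assert not cuda_toolkit_version_satisfies("11.8", 12, 1)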
@@ -586,18 +583,9 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,

       return schedules, stream_k_schedules

-  if grouped:
-    pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
-    cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum
-    if can_do_tma_epilogue:
-      schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong])
-    if can_do_cooperative:
-      schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative])
-    return schedules, []
-
   schedules = []
-  # Pruning: emit Void-C kernels with persistent kernels only
-  if level >= 1 or not is_void_c:
+  # Pruning: emit Void-C and Grouped kernels with persistent kernels only
+  if (level >= 1 or not is_void_c) and not grouped:
     # Pruning: don't stamp out fp8 kernels with auto schedule
     if not is_fp8:
       schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
@@ -610,28 +598,29 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     # Inconsistency: fp8 pingpong only gets stamped out with fast accum
     if not is_fp8 or level >= 1:
       schedules.append([
-        KernelScheduleType.TmaWarpSpecializedPingpong,
-        EpilogueScheduleType.TmaWarpSpecialized
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
       ])
     if can_do_fp8_fast_accum:
       schedules.append([
-        KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
-        EpilogueScheduleType.TmaWarpSpecialized
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped),
+        to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
       ])

   if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
-    # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
+    # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue
     if not is_fp8 or level >= 1:
-      schedules.append([KernelScheduleType.TmaWarpSpecializedPingpong, default_epilogue])
+      schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])

     if can_do_fp8_fast_accum:
-      schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
-      schedules.append([KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, default_epilogue])
+      if not grouped:
+        schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
+      schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])

     if can_do_cooperative:
       schedules.append([
-        KernelScheduleType.TmaWarpSpecializedCooperative,
-        default_epilogue
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+        to_grouped_schedule(default_epilogue, grouped)
       ])
       stream_k_schedules.append([
         KernelScheduleType.TmaWarpSpecializedCooperative,
@@ -639,8 +628,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
       ])
     if can_do_fp8_fast_accum:
       schedules.append([
-        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
-        default_epilogue
+        to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+        to_grouped_schedule(default_epilogue, grouped)
       ])
       stream_k_schedules.append([
         KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
@@ -652,8 +641,8 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
   assert not requires_transposed_epilogue
   if can_do_cooperative:
     schedules.append([
-      KernelScheduleType.TmaWarpSpecializedCooperative,
-      EpilogueScheduleType.TmaWarpSpecializedCooperative
+      to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
+      to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
     ])
     stream_k_schedules.append([
       KernelScheduleType.TmaWarpSpecializedCooperative,
@@ -661,14 +650,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     ])
   if can_do_fp8_fast_accum:
     schedules.append([
-      KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
-      EpilogueScheduleType.TmaWarpSpecializedCooperative
+      to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
+      to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
     ])
     stream_k_schedules.append([
       KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
       EpilogueScheduleType.TmaWarpSpecializedCooperative
     ])
+  # Grouped GEMM do not support Stream-K scheduler
+  if grouped:
+    return schedules, []
   return schedules, stream_k_schedules

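The early return added above means grouped GEMM never contributes stream-K schedules. A trivial standalone sketch of that behavior, with a made-up function name and toy schedule values purely for illustration:

def finalize_schedules(schedules, stream_k_schedules, grouped):
  # Hypothetical standalone version of the early return added above:
  # grouped (ptr-array) GEMM does not use the stream-K tile scheduler,
  # so its stream-K schedule list is always empty.
  if grouped:
    return schedules, []
  return schedules, stream_k_schedules

# A grouped build keeps its regular schedules but drops the stream-K ones.
assert finalize_schedules(["coop"], ["coop_sk"], grouped=True) == (["coop"], [])
assert finalize_schedules(["coop"], ["coop_sk"], grouped=False) == (["coop"], ["coop_sk"])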
@@ -204,9 +204,6 @@ protected:
     fusion_args.beta_ptr = nullptr;
     fusion_args.alpha_ptr_array = nullptr;
     fusion_args.beta_ptr_array = nullptr;
-    // Single alpha and beta for all groups
-    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
-    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};

     return Status::kSuccess;
   }

@@ -215,6 +212,8 @@ protected:
     fusion_args.beta = 0;
     fusion_args.alpha_ptr = static_cast<ElementCompute const*>(arguments.alpha);
     fusion_args.beta_ptr = static_cast<ElementCompute const*>(arguments.beta);
+    fusion_args.alpha_ptr_array = nullptr;
+    fusion_args.beta_ptr_array = nullptr;
     return Status::kSuccess;
   }
   else {