1125 lines
50 KiB
C++
1125 lines
50 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once
|
|
|
|
#include "cutlass/arch/arch.h"
|
|
#include "cutlass/gemm/gemm.h"
|
|
|
|
#include "cute/layout.hpp"
|
|
#include "cute/numeric/integral_constant.hpp" // cute::false_type
|
|
#include "cute/atom/copy_traits_sm100.hpp"
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass::detail {
|
|
|
|
template <class T, template <int...> class U>
|
|
struct is_kernel_tag_of : cute::false_type {};
|
|
|
|
template <template <int...> class U, int... Args>
|
|
struct is_kernel_tag_of<U<Args...>, U> : cute::true_type {};
|
|
|
|
template <class T, template <int...> class U>
|
|
constexpr bool is_kernel_tag_of_v = is_kernel_tag_of<T, U>::value;
|
|
|
|
template <class T, template <int,bool> class U>
|
|
struct is_asymmetric_dma_kernel_tag_of : cute::false_type {};
|
|
|
|
template <template <int, bool> class U, int I0, bool B0>
|
|
struct is_asymmetric_dma_kernel_tag_of<U<I0, B0>, U> : cute::true_type {};
|
|
|
|
template <class T, template <int, bool> class U>
|
|
constexpr bool is_asymmetric_dma_kernel_tag_of_v = \
|
|
is_asymmetric_dma_kernel_tag_of<T, U>::value;
|
|
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass::gemm {
|
|
using namespace cute;
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace detail {

// Names the kernel input-transform kinds. Enumerator values are implicit,
// i.e. 0, 1, 2 in declaration order.
enum class KernelInputTransformType {
  FastF32,                 // 0
  InterleavedComplexTF32,  // 1
  MixedInput               // 2
};

} // namespace detail
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace kernel::detail {
|
|
|
|
// Has_SwapAB<T>::value will be true only if:
|
|
// class T has member SwapAB and T::SwapAB is true
|
|
template <typename T, typename = void>
|
|
struct Has_SwapAB { static constexpr bool value = false; };
|
|
|
|
template <typename T>
|
|
struct Has_SwapAB <T, CUTE_STL_NAMESPACE::void_t<decltype(T::SwapAB)>>
|
|
{ static constexpr bool value = T::SwapAB; };
|
|
|
|
template <typename T>
|
|
static constexpr bool Has_SwapAB_v = Has_SwapAB<T>::value;
|
|
|
|
// additional producer warp role check for block scaling mainloop
|
|
template<typename T>
|
|
struct HasAuxiliaryLoad : cute::false_type{};
|
|
|
|
template <typename T>
|
|
static constexpr bool HasAuxiliaryLoad_v = HasAuxiliaryLoad<T>::value;
|
|
|
|
} // namespace kernel::detail
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
// Kernel schedule policies (the base class tags, one for each kernel layer file)
//

// Empty tag types: they carry no data and are distinguished by type only.
struct KernelMultistage {};
struct KernelCpAsyncWarpSpecialized {};
struct KernelCpAsyncWarpSpecializedPingpong {};
struct KernelCpAsyncWarpSpecializedCooperative {};
struct KernelTma {};
struct KernelTmaWarpSpecialized {};

// Pingpong TMA warp-specialized tag; exposes a scheduler pipeline stage count of 0.
struct KernelTmaWarpSpecializedPingpong {
  constexpr static int SchedulerPipelineStageCount = 0;
};

// Cooperative TMA warp-specialized tag; exposes a scheduler pipeline stage count of 0.
struct KernelTmaWarpSpecializedCooperative {
  constexpr static int SchedulerPipelineStageCount = 0;
};
|
|
|
|
// Empty base tags for the Ptr-Array TMA warp-specialized schedules.
struct KernelPtrArrayTmaWarpSpecializedCooperative {};
struct KernelPtrArrayTmaWarpSpecializedPingpong {};
|
|
|
|
// FP8 related policies (including Blocked Scaled Accumulation)
|
|
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
|
|
struct KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelTmaWarpSpecializedPingpong { };
|
|
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { };
|
|
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedPingpong { };
|
|
|
|
// Policies to opt into mixed type GEMMs
|
|
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
|
|
struct KernelTmaWarpSpecializedPingpongMixedInput : KernelTmaWarpSpecializedPingpong { };
|
|
struct KernelTmaWarpSpecializedCooperativeMixedInput: KernelTmaWarpSpecializedCooperative { };
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
// Builder dispatch policies (not a part of the main CUTLASS layers, simply used to opt into
|
|
// specific collective builder dispatches)
|
|
//
|
|
|
|
// FP8 related policies (including Fast Accumulation)
|
|
struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
|
|
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
|
|
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
|
|
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { };
|
|
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum : KernelPtrArrayTmaWarpSpecializedPingpong { };
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Policies for dispatch of epilogue (empty tag types).
struct EpilogueDefault {};
struct EpilogueTransposed {};
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
// Collective Mainloop Policies
|
|
//
|
|
|
|
// 2 stage pipeline through 1 stage in smem, 1 in rmem, WITHOUT predicated gmem loads
|
|
struct MainloopSm70TwoStageUnpredicated {
|
|
constexpr static int Stages = 2;
|
|
using ArchTag = arch::Sm70;
|
|
using Schedule = KernelMultistage;
|
|
using ClusterShape = Shape<_1,_1,_1>;
|
|
};
|
|
|
|
// 2 stage pipeline through 1 stage in smem, 1 in rmem, with predicated gmem loads
|
|
struct MainloopSm70TwoStage {
|
|
constexpr static int Stages = 2;
|
|
using ArchTag = arch::Sm70;
|
|
using Schedule = KernelMultistage;
|
|
using ClusterShape = Shape<_1,_1,_1>;
|
|
};
|
|
|
|
// n-buffer in smem (cp.async), pipelined with registers, WITHOUT predicated gmem loads
|
|
template<int Stages_>
|
|
struct MainloopSm80CpAsyncUnpredicated {
|
|
constexpr static int Stages = Stages_;
|
|
using ArchTag = arch::Sm80;
|
|
using Schedule = KernelMultistage;
|
|
using ClusterShape = Shape<_1,_1,_1>;
|
|
};
|
|
|
|
// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm80CpAsync {
|
|
constexpr static int Stages = Stages_;
|
|
using ArchTag = cute::conditional_t<(size(ClusterShape_{}) > 1), arch::Sm90, arch::Sm80>;
|
|
using Schedule = KernelMultistage;
|
|
using ClusterShape = ClusterShape_;
|
|
};
|
|
|
|
// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelCpAsyncWarpSpecialized
|
|
>
|
|
struct MainloopSm90CpAsyncGmmaWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
};
|
|
|
|
// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelCpAsyncWarpSpecialized
|
|
>
|
|
struct MainloopSm90CpAsyncGmmaRmemAWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
int PipelineAsyncMmaStages_ = 1
|
|
>
|
|
struct MainloopSm90TmaGmma {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelTma;
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelTmaWarpSpecializedCooperative
|
|
>
|
|
struct MainloopSm90TmaGmmaWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
|
|
// With GMMA's A data from registers.
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelTmaWarpSpecialized
|
|
>
|
|
struct MainloopSm90TmaGmmaRmemAWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
static_assert(
|
|
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
|
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
|
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
|
|
"KernelSchedule must be one of the warp specialized policies");
|
|
};
|
|
|
|
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelTmaWarpSpecialized
|
|
>
|
|
struct MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
static_assert(
|
|
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
|
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
|
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
|
|
"KernelSchedule must be one of the warp specialized policies");
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
|
|
// For FP8 kernels
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelTmaWarpSpecialized
|
|
>
|
|
struct MainloopSm90TmaGmmaWarpSpecializedFP8
|
|
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
|
static_assert(
|
|
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecialized> ||
|
|
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedPingpong> ||
|
|
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
|
"KernelSchedule must be one of the warp specialized policies");
|
|
};
|
|
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
|
|
// For FP8 kernels with Block Scaling
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
|
|
>
|
|
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
|
|
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
|
static_assert(
|
|
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum> ||
|
|
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum>,
|
|
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
|
|
>
|
|
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
constexpr static int PipelineAsyncMmaStages = 1;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
static_assert(
|
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
|
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>,
|
|
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
|
|
// For FP8 kernels
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
|
|
>
|
|
struct MainloopSm90ArrayTmaGmmaWarpSpecializedFP8
|
|
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
|
static_assert(
|
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
|
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>,
|
|
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper sparse GMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelTmaWarpSpecializedCooperative
|
|
>
|
|
struct MainloopSm90TmaGmmaWarpSpecializedSparse {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
};
|
|
|
|
// For slow-accumulation sparse FP8 kernels
|
|
template<
|
|
int Stages,
|
|
class ClusterShape = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelTmaWarpSpecializedCooperative
|
|
>
|
|
struct MainloopSm90TmaGmmaWarpSpecializedSparseFP8
|
|
: MainloopSm90TmaGmmaWarpSpecializedSparse<Stages, ClusterShape, KernelSchedule> {
|
|
};
|
|
|
|
// Mixed precision version n-buffer in rmem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
|
|
>
|
|
struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm90;
|
|
using Schedule = KernelSchedule;
|
|
static_assert(
|
|
cute::is_same_v<Schedule, KernelPtrArrayTmaWarpSpecializedCooperative> ||
|
|
cute::is_same_v<Schedule, KernelPtrArrayTmaWarpSpecializedPingpong>,
|
|
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
|
|
};
|
|
|
|
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
|
|
// For FP8 kernels with Block Scaling
|
|
template<
|
|
int Stages_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum
|
|
>
|
|
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling
|
|
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
|
static_assert(
|
|
cute::is_any_of_v<
|
|
KernelSchedule,
|
|
KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum,
|
|
KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum
|
|
>,
|
|
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
|
|
};
|
|
|
|
|
|
// Kernel schedule tag that carries two compile-time pipeline stage counts.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelTmaWarpSpecializedSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Gemm with block scaling factors.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelTmaWarpSpecializedBlockScaledSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Kernel schedule tag (MMA-transform variant, per its name) carrying two stage counts.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelTmaWarpSpecializedMmaTransformSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Kernel schedule tag (Ptr-Array MMA-transform variant, per its name) carrying two stage counts.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelPtrArrayTmaWarpSpecializedMmaTransformSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Sparse Gemm.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelSparseTmaWarpSpecializedSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Sparse Gemm with block scaling factors.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelSparseTmaWarpSpecializedBlockScaledSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// InputTransform GEMM.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelTmaWarpSpecializedInputTransformSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Ptr-Array Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelPtrArrayTmaWarpSpecializedSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Ptr-Array Block Scaled GEMM.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelPtrArrayTmaWarpSpecializedBlockScaledSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
// Ptr-Array InputTransform GEMM.
// SchedulerStages_   : re-exported as `SchedulerPipelineStageCount`.
// AccumulatorStages_ : re-exported as `AccumulatorPipelineStageCount`.
template <int SchedulerStages_, int AccumulatorStages_>
struct KernelPtrArrayTmaWarpSpecializedInputTransformSm100 final {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static int AccumulatorPipelineStageCount = AccumulatorStages_;
};
|
|
|
|
|
|
// SM120 kernel schedules
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelTmaWarpSpecializedCooperativeSm120 : KernelTmaWarpSpecializedCooperative {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelTmaWarpSpecializedPingpongSm120 : KernelTmaWarpSpecializedPingpong {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelTmaWarpSpecializedCooperativeBlockScaledSm120 : KernelTmaWarpSpecializedCooperative {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelTmaWarpSpecializedPingpongBlockScaledSm120 : KernelTmaWarpSpecializedPingpong {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
// SM120 dense Ptr-array kernel schedules
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelPtrArrayTmaWarpSpecializedCooperativeSm120 : KernelPtrArrayTmaWarpSpecializedCooperative {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelPtrArrayTmaWarpSpecializedPingpongSm120 : KernelPtrArrayTmaWarpSpecializedPingpong {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelPtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120 : KernelPtrArrayTmaWarpSpecializedCooperative {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
template< int SchedulerPipelineStageCount_>
|
|
struct KernelPtrArrayTmaWarpSpecializedPingpongBlockScaledSm120 : KernelPtrArrayTmaWarpSpecializedPingpong {
|
|
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
};
|
|
|
|
// SM120 sparse kernel schedules.
// Stand-alone tag (no base class) exposing both template arguments as constants.
// SchedulerStages_ : re-exported as `SchedulerPipelineStageCount`.
// Asymmetric_      : re-exported as `isAsymmetric`.
template <int SchedulerStages_, bool Asymmetric_>
struct KernelTmaWarpSpecializedCooperativeSparseSm120 {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static bool isAsymmetric = Asymmetric_;
};
|
|
|
|
// SM120 sparse block-scaled cooperative schedule.
// Stand-alone tag (no base class) exposing both template arguments as constants.
// SchedulerStages_ : re-exported as `SchedulerPipelineStageCount`.
// Asymmetric_      : re-exported as `isAsymmetric`.
template <int SchedulerStages_, bool Asymmetric_>
struct KernelTmaWarpSpecializedCooperativeSparseBlockScaledSm120 {
  constexpr static int SchedulerPipelineStageCount = SchedulerStages_;
  constexpr static bool isAsymmetric = Asymmetric_;
};
|
|
|
|
// Auxiliary Load Tag.
|
|
|
|
namespace kernel::detail {
|
|
|
|
template<
|
|
int Stages,
|
|
class ClusterShape,
|
|
class KernelSchedule
|
|
>
|
|
struct HasAuxiliaryLoad<
|
|
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<
|
|
Stages,
|
|
ClusterShape,
|
|
KernelSchedule
|
|
>
|
|
> : cute::true_type{};
|
|
|
|
template<
|
|
int Stages,
|
|
class ClusterShape,
|
|
class KernelSchedule
|
|
>
|
|
struct HasAuxiliaryLoad<
|
|
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<
|
|
Stages,
|
|
ClusterShape,
|
|
KernelSchedule
|
|
>
|
|
> : cute::true_type{};
|
|
|
|
} // namespace kernel::detail
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
// Collective Builder Tag Property
|
|
//
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// SM100 Dispatch Policies
|
|
//
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Builder Tag Base Dispatch Policies.
// Empty tags used as base classes by the SM100 builder dispatch policies.
struct KernelSchedule1Sm { };
struct KernelSchedule2Sm { };
struct KernelScheduleSm100 { };
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Dense GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy
|
|
// Dense GEMM: Specialize for 1SM vs 2SM
|
|
struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder
|
|
struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Ptr-Array Dense GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Dense GEMM + (Ptr Array or Group GEMM)
|
|
struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {};
|
|
// Ptr-Array Dense GEMM: Specialize for 1SM vs 2SM
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {};
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Blockwise GEMM + Ptr-Array GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleSm100Blockwise : KernelScheduleSm100 {};
|
|
struct KernelTmaWarpSpecializedBlockwise1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100Blockwise {};
|
|
struct KernelTmaWarpSpecializedBlockwise2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100Blockwise {};
|
|
|
|
struct KernelScheduleSm100PtrArrayBlockwise : KernelScheduleSm100Blockwise {};
|
|
struct KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayBlockwise {};
|
|
struct KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayBlockwise {};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Planar Complex GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleSm100PlanarComplexGemm : KernelScheduleSm100{};
|
|
// Planar Complex GEMM: Specialize for 1SM vs 2SM
|
|
struct KernelTmaWarpSpecialized1SmPlanarComplexSm100 final : KernelSchedule1Sm, KernelScheduleSm100PlanarComplexGemm { };
|
|
struct KernelTmaWarpSpecialized2SmPlanarComplexSm100 final : KernelSchedule2Sm, KernelScheduleSm100PlanarComplexGemm { };
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Ptr-Array Planar Complex GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Planar Complex GEMM + (Ptr Array or Group GEMM)
|
|
struct KernelScheduleSm100PtrArrayPlanarComplexGemm : KernelScheduleSm100PlanarComplexGemm {};
|
|
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmPlanarComplexSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayPlanarComplexGemm {};
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmPlanarComplexSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayPlanarComplexGemm {};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 FastF32 (9xBF16) GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleSm100FastFP32Gemm : KernelScheduleSm100 {};
|
|
struct KernelTmaWarpSpecializedFastFP32SmemSm100 : KernelScheduleSm100FastFP32Gemm { };
|
|
// Dispatch policies without smem load the A operand from tmem
|
|
struct KernelTmaWarpSpecialized1SmFastFP32Sm100 final : KernelSchedule1Sm, KernelScheduleSm100FastFP32Gemm { };
|
|
struct KernelTmaWarpSpecialized2SmFastFP32Sm100 final : KernelSchedule2Sm, KernelScheduleSm100FastFP32Gemm { };
|
|
// Dispatch policies with smem load the A operand from smem
|
|
struct KernelTmaWarpSpecialized1SmFastFP32SmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedFastFP32SmemSm100 { };
|
|
struct KernelTmaWarpSpecialized2SmFastFP32SmemSm100 final : KernelSchedule2Sm, KernelTmaWarpSpecializedFastFP32SmemSm100 { };
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Mixed Precision Input GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleSm100MixedInputGemm : KernelScheduleSm100 {};
|
|
struct KernelTmaWarpSpecializedMixedInputSmemSm100 : KernelScheduleSm100MixedInputGemm { };
|
|
struct KernelTmaWarpSpecialized1SmMixedInputSm100 final : KernelSchedule1Sm, KernelScheduleSm100MixedInputGemm { };
|
|
struct KernelTmaWarpSpecialized1SmMixedInputSmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedMixedInputSmemSm100 { };
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Ptr-Array FastF32 (9xBF16) GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Ptr-Array Transform GEMM: Specialize for 1SM vs 2SM FastF32 GEMM
|
|
struct KernelScheduleSm100PtrArrayFastFP32Gemm : KernelScheduleSm100FastFP32Gemm {};
|
|
struct KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 : KernelScheduleSm100PtrArrayFastFP32Gemm { };
|
|
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmFastFP32Sm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayFastFP32Gemm { };
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmFastFP32Sm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayFastFP32Gemm { };
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmFastFP32SmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmFastFP32SmemSm100 final : KernelSchedule2Sm, KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 { };
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 Sparse GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleSparseGemmSm100 : KernelScheduleSm100 {};
|
|
// Sparse GEMM: Specialize for 1SM vs 2SM
|
|
struct KernelSparseTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSparseGemmSm100 { };
|
|
struct KernelSparseTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSparseGemmSm100 { };
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 BlockScaled Dense GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleBlockScaledGemmSm100 : KernelScheduleSm100 {};
|
|
struct KernelScheduleMxNvf4Sm100 : KernelScheduleBlockScaledGemmSm100 {};
|
|
struct KernelScheduleMxf8f6f4Sm100 : KernelScheduleBlockScaledGemmSm100 {};
|
|
// Block Scaled Dense GEMM: Specialize for instruction type, scale factor vector size, and 1SM vs. 2SM
|
|
struct KernelTmaWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 { };
|
|
struct KernelTmaWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelScheduleBlockScaledGemmSm100 { };
|
|
struct KernelTmaWarpSpecialized1SmNvf4Sm100 final : KernelSchedule1Sm, KernelScheduleMxNvf4Sm100 { };
|
|
struct KernelTmaWarpSpecialized2SmNvf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { };
|
|
struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelScheduleMxNvf4Sm100 { };
|
|
struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { };
|
|
struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { };
|
|
struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { };
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// BlockScaled Dense GEMM + (Ptr Array or Group GEMM)
|
|
struct KernelSchedulePtrArrayBlockScaledGemmSm100 : KernelScheduleBlockScaledGemmSm100 {};
|
|
struct KernelSchedulePtrArrayMxNvf4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {};
|
|
struct KernelSchedulePtrArrayMxf8f6f4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {};
|
|
// Ptr-Array Block Scaled Dense GEMM: Specialize for instruction type, scale factor vector size, and 1SM vs. 2SM
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayBlockScaledGemmSm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayBlockScaledGemmSm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxNvf4Sm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxNvf4Sm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxNvf4Sm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxNvf4Sm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { };
|
|
struct KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { };
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM100 BlockScaled Sparse GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
struct KernelScheduleBlockScaledSparseGemmSm100 : KernelScheduleSm100 {};
|
|
struct KernelScheduleSparseMxNvf4Sm100 : KernelScheduleBlockScaledSparseGemmSm100 {};
|
|
struct KernelScheduleSparseMxf8f6f4Sm100 : KernelScheduleBlockScaledSparseGemmSm100 {};
|
|
// Block Scaled Sparse GEMM: Specialize for instruction type, scale factor vector size, and 1SM vs. 2SM
|
|
struct KernelSparseTmaWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledSparseGemmSm100 {};
|
|
struct KernelSparseTmaWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelScheduleBlockScaledSparseGemmSm100 {};
|
|
struct KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleSparseMxf8f6f4Sm100 { };
|
|
struct KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleSparseMxf8f6f4Sm100 { };
|
|
struct KernelSparseTmaWarpSpecialized1SmNvf4Sm100 final : KernelSchedule1Sm, KernelScheduleSparseMxNvf4Sm100 { };
|
|
struct KernelSparseTmaWarpSpecialized2SmNvf4Sm100 final : KernelSchedule2Sm, KernelScheduleSparseMxNvf4Sm100 { };
|
|
struct KernelSparseTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelScheduleSparseMxNvf4Sm100 { };
|
|
struct KernelSparseTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleSparseMxNvf4Sm100 { };
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// SM120 Dispatch Policies
|
|
//
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Builder Tag Base Dispatch Policies
|
|
// Root tag for the SM120 kernel schedules declared in this section. Empty type with no bases.
struct KernelScheduleSm120 {};
// Independent empty tag; used as an additional base by the Acc2x4 sparse schedule tag.
struct KernelScheduleAcc2x4Sm120 {};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM120 Dense GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Empty base tag for the SM120 dense GEMM schedules; derives from KernelScheduleSm120.
struct KernelScheduleSm120DenseGemm : KernelScheduleSm120 {};
// Dense GEMM: Specialize for instruction type
struct KernelScheduleF8f6f4Sm120 final : KernelScheduleSm120DenseGemm {};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM120 BlockScaled GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Empty base tag for the SM120 block-scaled GEMM schedules; derives from KernelScheduleSm120.
struct KernelScheduleBlockScaledGemmSm120 : KernelScheduleSm120 {};
// Instruction-family tags; both derive from the block-scaled base tag.
struct KernelScheduleMxf8f6f4Sm120 : KernelScheduleBlockScaledGemmSm120 {};
struct KernelScheduleMxNvf4Sm120 : KernelScheduleBlockScaledGemmSm120 {};
// Block Scaled GEMM: Specialize for instruction type, scale factor vector size.
// Each leaf tag also derives from KernelTmaWarpSpecializedCooperative or
// KernelTmaWarpSpecializedPingpong (the "Pingpong"-named tags use the latter).
struct KernelTmaWarpSpecializedNvf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedCooperative {};
struct KernelTmaWarpSpecializedPingpongNvf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedPingpong {};
struct KernelTmaWarpSpecializedMxf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedCooperative {};
struct KernelTmaWarpSpecializedPingpongMxf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedPingpong {};
struct KernelTmaWarpSpecializedMxf8f6f4Sm120 final : KernelScheduleMxf8f6f4Sm120, KernelTmaWarpSpecializedCooperative {};
struct KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120 final : KernelScheduleMxf8f6f4Sm120, KernelTmaWarpSpecializedPingpong {};
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM120 Sparse GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Empty base tag for the SM120 sparse GEMM schedules; derives from KernelScheduleSm120.
struct KernelScheduleSparseGemmSm120 : KernelScheduleSm120 {};
// Sparse GEMM: Specialize for instruction type
struct KernelScheduleSparseF8f6f4Sm120 final : KernelScheduleSparseGemmSm120 {};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// SM120 BlockScaled Sparse GEMM Dispatch Policies
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// Empty base tag for the SM120 block-scaled sparse GEMM schedules; derives from KernelScheduleSm120.
struct KernelScheduleBlockScaledSparseGemmSm120 : KernelScheduleSm120 {};
// Instruction-family tags; both derive from the block-scaled sparse base tag.
struct KernelScheduleSparseMxNvf4Sm120 : KernelScheduleBlockScaledSparseGemmSm120 {};
struct KernelScheduleSparseMxf8f6f4Sm120 : KernelScheduleBlockScaledSparseGemmSm120 {};
// Block Scaled Sparse GEMM: Specialize for instruction type, scale factor vector size, Acc2x4
struct KernelSparseTmaWarpSpecializedNvf4Sm120 final : KernelScheduleSparseMxNvf4Sm120 {};
struct KernelSparseTmaWarpSpecializedMxf4Sm120 final : KernelScheduleSparseMxNvf4Sm120 {};
struct KernelSparseTmaWarpSpecializedMxf8f6f4Sm120 final : KernelScheduleSparseMxf8f6f4Sm120 {};
// The Acc2x4 variant additionally derives from KernelScheduleAcc2x4Sm120.
struct KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120 final : KernelScheduleSparseMxf8f6f4Sm120, KernelScheduleAcc2x4Sm120 {};
|
|
|
|
|
|
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100TmaUmmaWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
using Schedule = KernelTmaWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
constexpr static bool IsOverlappingAccum = false;
|
|
};
|
|
|
|
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
using Schedule = KernelTmaWarpSpecializedMmaTransformSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
constexpr static bool IsOverlappingAccum = false;
|
|
};
|
|
|
|
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
using Schedule = KernelPtrArrayTmaWarpSpecializedMmaTransformSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
constexpr static bool IsOverlappingAccum = false;
|
|
};
|
|
|
|
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100TmaUmmaWarpSpecializedBlockScaled {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1;
|
|
using Schedule = KernelTmaWarpSpecializedBlockScaledSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
};
|
|
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100TmaUmmaWarpSpecializedSparse {
|
|
constexpr static int Stages = Stages_;
|
|
constexpr static int MetadataS2TStages = 4;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1;
|
|
using Schedule = KernelSparseTmaWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
};
|
|
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100TmaUmmaWarpSpecializedBlockScaledSparse {
|
|
constexpr static int Stages = Stages_;
|
|
constexpr static int MetadataS2TStages = 4;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1;
|
|
using Schedule = KernelSparseTmaWarpSpecializedBlockScaledSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
};
|
|
|
|
// n-buffer in smem, pipelined with Blackwell Fast FP32 kernel with UMMA (HwScaled) and TMA,
|
|
// Warp specialized dynamic schedule
|
|
template<
|
|
// Number of Pipeline stages for
|
|
// MainloopLoad <-> Conversion <-> MainLoad
|
|
int Load2TransformPipelineStageCount_,
|
|
// Number of Pipeline stages for
|
|
// MainloopLoad <-> Conversion <-> MainLoad
|
|
int Transform2MmaPipelineStageCount_,
|
|
// TileScheduler pipeline depth
|
|
int SchedulerPipelineStageCount_,
|
|
// Accmulator pipeline depth
|
|
int AccumulatorPipelineStageCount_,
|
|
// Number of MMA Bands to be computed in a single FastF32 MMA operation.
|
|
// For BF16 emulation, we have 3 compute matrices, with 9 MMAs forming 5 bands.
|
|
// We can eliminate bands 4 and/or 5 (up to last 3 MMA operations).
|
|
// Valid values are 3, 4, 5
|
|
int NumBandsToCompute_,
|
|
// Scaling factor for decomposed matrices (2^ScalingFactor)
|
|
// 8 for BF16, 11 for TF32
|
|
int ScalingFactor_,
|
|
// Number of UMMA instructions emulated a single stage
|
|
// Ex: Staged16 has 1 FastF32 MMA per stage
|
|
// Should be smaller than K-mode of a single ClusterTile
|
|
int AccPromotionInterval_,
|
|
// ClusterShape for the kernel
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
// The TMEM_LOAD atom to be used for loading local accumulator
|
|
// from TMEM to registers
|
|
class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x
|
|
>
|
|
struct MainloopSm100TmaUmmaWarpSpecializedFastF32 {
|
|
constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_;
|
|
constexpr static int Transform2MmaPipelineStageCount = Transform2MmaPipelineStageCount_;
|
|
constexpr static int NumBandsToCompute = NumBandsToCompute_;
|
|
constexpr static int ScalingFactor = ScalingFactor_;
|
|
constexpr static int AccPromotionInterval = AccPromotionInterval_;
|
|
constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::FastF32;
|
|
using ClusterShape = ClusterShape_;
|
|
using AccumulatorCopyAtom = AccumulatorCopyAtom_;
|
|
using ArchTag = arch::Sm100;
|
|
using Schedule = KernelTmaWarpSpecializedInputTransformSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
|
|
// For backwards compatibility with GemmUniversalAdapter.
|
|
constexpr static int Stages = Load2TransformPipelineStageCount;
|
|
};
|
|
|
|
|
|
|
|
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100ArrayTmaUmmaWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
constexpr static bool IsOverlappingAccum = false;
|
|
using Schedule = KernelPtrArrayTmaWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
};
|
|
|
|
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
int AccumulatorPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm100;
|
|
constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1;
|
|
using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
};
|
|
|
|
|
|
|
|
// n-buffer in smem, pipelined with Blackwell Fast FP32 kernel with UMMA (HwScaled) and TMA,
|
|
// Warp specialized dynamic schedule
|
|
template<
|
|
// Number of Pipeline stages for
|
|
// MainloopLoad <-> Conversion <-> MainLoad
|
|
int Load2TransformPipelineStageCount_,
|
|
// Number of Pipeline stages for
|
|
// MainloopLoad <-> Conversion <-> MainLoad
|
|
int Transform2MmaPipelineStageCount_,
|
|
// TileScheduler pipeline depth
|
|
int SchedulerPipelineStageCount_,
|
|
// Accmulator pipeline depth
|
|
int AccumulatorPipelineStageCount_,
|
|
// Number of MMA Bands to be computed in a single FastF32 MMA operation.
|
|
// For BF16 emulation, we have 3 compute matrices, with 9 MMAs forming 5 bands.
|
|
// We can eliminate bands 4 and/or 5 (up to last 3 MMA operations).
|
|
// Valid values are 3, 4, 5
|
|
int NumBandsToCompute_,
|
|
// Scaling factor for decomposed matrices (2^ScalingFactor)
|
|
// 8 for BF16, 11 for TF32
|
|
int ScalingFactor_,
|
|
// Number of UMMA instructions emulated a single stage
|
|
// Ex: Staged16 has 1 FastF32 MMA per stage
|
|
// Should be smaller than K-mode of a single ClusterTile
|
|
int AccPromotionInterval_,
|
|
// ClusterShape for the kernel
|
|
class ClusterShape_ = Shape<_1,_1,_1>,
|
|
// The TMEM_LOAD atom to be used for loading local accumulator
|
|
// from TMEM to registers
|
|
class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x
|
|
>
|
|
struct MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32 {
|
|
constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_;
|
|
constexpr static int Transform2MmaPipelineStageCount = Transform2MmaPipelineStageCount_;
|
|
constexpr static int NumBandsToCompute = NumBandsToCompute_;
|
|
constexpr static int ScalingFactor = ScalingFactor_;
|
|
constexpr static int AccPromotionInterval = AccPromotionInterval_;
|
|
constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::FastF32;
|
|
using ClusterShape = ClusterShape_;
|
|
using AccumulatorCopyAtom = AccumulatorCopyAtom_;
|
|
using ArchTag = arch::Sm100;
|
|
using Schedule = KernelPtrArrayTmaWarpSpecializedInputTransformSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
|
|
|
// For backwards compatibility with GemmUniversalAdapter.
|
|
constexpr static int Stages = Load2TransformPipelineStageCount;
|
|
};
|
|
|
|
|
|
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
class ClusterShape_,
|
|
class KernelSchedule_
|
|
>
|
|
struct MainloopSm120TmaWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using Schedule = KernelSchedule_;
|
|
constexpr static int PipelineAsyncMmaStages = 0;
|
|
using ArchTag = arch::Sm120;
|
|
};
|
|
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
class ClusterShape_,
|
|
class KernelSchedule_
|
|
>
|
|
struct MainloopSm120ArrayTmaWarpSpecialized {
|
|
constexpr static int Stages = Stages_;
|
|
using ClusterShape = ClusterShape_;
|
|
using Schedule = KernelSchedule_;
|
|
constexpr static int PipelineAsyncMmaStages = 0;
|
|
using ArchTag = arch::Sm120;
|
|
static_assert(
|
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, Schedule> ||
|
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, Schedule>,
|
|
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
|
|
};
|
|
|
|
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
class ClusterShape_,
|
|
class KernelSchedule_
|
|
>
|
|
struct MainloopSm120TmaWarpSpecializedBlockScaled {
|
|
constexpr static int Stages = Stages_;
|
|
constexpr static int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
using ClusterShape = ClusterShape_;
|
|
using Schedule = KernelSchedule_;
|
|
constexpr static int PipelineAsyncMmaStages = 0;
|
|
using ArchTag = arch::Sm120;
|
|
};
|
|
|
|
template<
|
|
int Stages_,
|
|
int SchedulerPipelineStageCount_,
|
|
class ClusterShape_,
|
|
class KernelSchedule_
|
|
>
|
|
struct MainloopSm120ArrayTmaWarpSpecializedBlockScaled {
|
|
constexpr static int Stages = Stages_;
|
|
constexpr static int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
|
using ClusterShape = ClusterShape_;
|
|
constexpr static int PipelineAsyncMmaStages = 0;
|
|
using Schedule = KernelSchedule_;
|
|
using ArchTag = arch::Sm120;
|
|
|
|
static_assert(cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, Schedule> ||
|
|
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, Schedule>,
|
|
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies.");
|
|
};
|
|
|
|
|
|
template<
|
|
int StagesA_,
|
|
int StagesB_,
|
|
int StagesE_,
|
|
int SchedulerPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm120TmaWarpSpecializedSparse {
|
|
constexpr static int StagesA = StagesA_;
|
|
constexpr static int StagesB = StagesB_;
|
|
constexpr static int StagesE = StagesE_;
|
|
constexpr static bool isAsymmetric = (StagesA != StagesB);
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm120;
|
|
using Schedule = KernelTmaWarpSpecializedCooperativeSparseSm120<SchedulerPipelineStageCount_, isAsymmetric>;
|
|
};
|
|
|
|
template<
|
|
int StagesA_,
|
|
int StagesB_,
|
|
int StagesE_,
|
|
int SchedulerPipelineStageCount_,
|
|
class ClusterShape_ = Shape<_1,_1,_1>
|
|
>
|
|
struct MainloopSm120TmaWarpSpecializedSparseBlockScaled {
|
|
constexpr static int StagesA = StagesA_;
|
|
constexpr static int StagesB = StagesB_;
|
|
constexpr static int StagesE = StagesE_;
|
|
constexpr static bool isAsymmetric = (StagesA != StagesB);
|
|
using ClusterShape = ClusterShape_;
|
|
using ArchTag = arch::Sm120;
|
|
using Schedule = KernelTmaWarpSpecializedCooperativeSparseBlockScaledSm120<SchedulerPipelineStageCount_, isAsymmetric>;
|
|
};
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace cutlass::gemm
|