Support for TMA Epilogue for Group Gemm and add pingpong ptr array & Group Gemm (#1795)

Author: Junkai-Wu
Date: 2024-09-11 12:07:31 +08:00
Committed by: GitHub
Parent: 21d0534167
Commit: dbdae514e0
23 changed files with 2356 additions and 344 deletions
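The pingpong Ptr-Array/Grouped GEMM path is selected through a small schedule-config struct plus the templated GemmGivenSchedule wrapper added in the examples below. A minimal selection sketch, not part of the diff, assuming the element, layout, Options, and run<> definitions from those examples are in scope:

// Hedged sketch: pick a schedule by choosing the config struct.
using GemmCooperative = GemmGivenSchedule<CooperativeConfig>::Gemm; // 2 warp-group cooperative epilogue
using GemmPingpong    = GemmGivenSchedule<PingpongConfig>::Gemm;    // pingpong mainloop/epilogue schedules

run<GemmCooperative>(options); // run<> is templated on the GEMM type, so arguments and workspace are queried per type
run<GemmPingpong>(options);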

View File

@ -95,40 +95,66 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
// Different configs for pingpong/cooperative
struct CooperativeConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using TileShape = Shape<_256,_128,_64>;
using ClusterShape = Shape<_1,_2,_1>;
};
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
struct PingpongConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = Shape<_64,_128,_64>;
using ClusterShape = Shape<_1,_1,_1>;
};
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
template <typename ScheduleConfig>
struct GemmGivenSchedule {
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
@ -261,14 +287,14 @@ bool initialize_block(
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(0);
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(-2);
} else {
scope_max = 8;
scope_min = -8;
scope_max = static_cast<Element>(8);
scope_min = static_cast<Element>(-8);
}
cutlass::reference::device::BlockFillRandomUniform(
@ -351,7 +377,8 @@ void initialize(const Options &options) {
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
template <typename GemmT>
typename GemmT::Arguments args_from_options(const Options &options)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
@ -359,7 +386,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{
typename GemmT::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kArray,
{{options.m, options.n, options.k, options.l}},
{ptr_A.get(), stride_A, ptr_B.get(), stride_B},
@ -405,20 +432,20 @@ bool verify(const Options &options) {
}
/// Execute a given example GEMM computation
template <typename Gemm>
template <typename GemmT>
int run(Options &options)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
GemmT gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
auto arguments = args_from_options<GemmT>(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = GemmT::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
@ -510,7 +537,10 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
std::cout << "\n*** Cooperative schedule ***" << std::endl;
run<Gemm>(options);
std::cout << "\n*** Pingpong schedule ***" << std::endl;
run<GemmPingpong>(options);
#endif
return 0;

View File

@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
// Different configs for pingpong/cooperative
struct CooperativeConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using TileShape = Shape<_256,_128,_128>;
using ClusterShape = Shape<_2,_2,_1>;
};
struct PingpongConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = Shape<_128,_128,_128>;
using ClusterShape = Shape<_2,_1,_1>;
};
template <typename ScheduleConfig>
struct GemmGivenSchedule {
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
@ -271,10 +297,10 @@ struct Options {
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = ((rand() % 512) + 1);
m = alignment * ((rand() % 64) + 1);
}
if (n < 1) {
n = ((rand() % 512) + 1);
n = alignment * ((rand() % 64) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 64) + 1);
@ -521,7 +547,8 @@ void initialize(const Options &options) {
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
template <typename GemmT>
typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
@ -529,33 +556,49 @@ typename Gemm::Arguments args_from_options(const Options &options, bool host_pro
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::EpilogueOutputOp::Params params;
typename GemmT::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
params = typename Gemm::EpilogueOutputOp::Params(
ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
typename Gemm::Arguments arguments;
if (host_problem_shapes_available) {
arguments = typename Gemm::Arguments {
arguments = typename GemmT::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else {
arguments = typename Gemm::Arguments {
arguments = typename GemmT::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
@ -605,20 +648,20 @@ bool verify(const Options &options) {
}
/// Execute a given example GEMM computation
template <typename Gemm>
template <typename GemmT>
int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
GemmT gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options, host_problem_shapes_available);
auto arguments = args_from_options<GemmT>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = GemmT::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
@ -713,8 +756,14 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
std::cout << "\n*** Cooperative schedule ***" << std::endl;
run<Gemm>(options);
std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl;
run<Gemm>(options, false /*host_problem_shapes_available*/);
std::cout << "\n*** Pingpong schedule ***" << std::endl;
run<GemmPingpong>(options);
std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl;
run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
#endif
return 0;

View File

@ -32,10 +32,10 @@
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes

View File

@ -274,4 +274,18 @@ struct conditional_template<false, True, False> {
using type = False<U...>;
};
//
// is_any_of
//
/// Member `value` is true if and only if T is the same as (is_same_v) at least one of the types in Us
template <typename T, typename... Us>
struct is_any_of {
constexpr static bool value = (... || CUTE_STL_NAMESPACE::is_same_v<T, Us>);
};
/// Is true if and only if T is the same as (is_same_v) at least one of the types in Us
template <typename T, typename... Us>
inline constexpr bool is_any_of_v = is_any_of<T, Us...>::value;
} // end namespace cute
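A brief usage sketch of the new trait (illustrative static_asserts, not part of the commit):

// is_any_of_v folds is_same_v over the candidate pack with logical OR.
static_assert(cute::is_any_of_v<float, int, float, double>);   // float matches the second candidate
static_assert(!cute::is_any_of_v<char, int, float, double>);   // char matches none of the candidates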

View File

@ -71,14 +71,18 @@ sm90_get_tma_dispatch_policy() {
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
constexpr bool ReuseSmem = (sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>) && (sizeof_bits_v<ElementD> > 8);
// TMA store delay performs worse with residual loads and complicates tensormap updates for Ptr-Array GEMMs
constexpr bool DelayTmaStore = is_void_v<ElementC> && !detail::sm90_is_tma_ptr_array_v<Schedule>;
constexpr bool DelayTmaStore = is_void_v<ElementC> && !detail::sm90_is_ptr_array_tma_v<Schedule>;
constexpr int StagesD = cute::min(EpiTiles, 2);
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
: cute::min(EpiTiles, 4);
return cute::conditional_t<detail::sm90_is_tma_ptr_array_v<Schedule>,
Sm90PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>,
Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>>{};
if constexpr (detail::sm90_is_ptr_array_tma_v<Schedule>) {
return Sm90PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem,
DelayTmaStore, Schedule::NumEpilogueWarpGroups>{};
}
else {
return Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
}
}
// Returns the smem layout atom to be used for C or D matrix
@ -255,6 +259,9 @@ struct Sm90TmaBuilderImpl {
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
using UnderlyingGmemStrideTypeC = cute::remove_pointer_t<GmemStrideTypeC>;
using UnderlyingGmemStrideTypeD = cute::remove_pointer_t<GmemStrideTypeD>;
using CopyOpS2G = cute::conditional_t<detail::is_im2col_mode<GmemLayoutTagD>,
SM90_TMA_STORE_IM2COL,
SM90_TMA_STORE
@ -267,17 +274,11 @@ struct Sm90TmaBuilderImpl {
// Get the smallest tiled copy we can use to retile the accumulators
using CopyAtomC = Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>;
using FusionDispatchPolicy = Sm90TmaWarpSpecialized<DispatchPolicy::StagesC,
DispatchPolicy::StagesD,
DispatchPolicy::FragmentSize,
DispatchPolicy::ReuseSmemC,
DispatchPolicy::DelayTmaStore>;
// TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
// instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
using FusionCallbacks =
typename CallbacksBuilder<
FusionDispatchPolicy,
DispatchPolicy,
FusionOpOrCallbacks,
TileShape_MNK,
EpilogueTile_MN,
@ -294,11 +295,11 @@ struct Sm90TmaBuilderImpl {
GmemStrideTypeD,
FusionCallbacks,
CopyOpG2S,
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, EpilogueTile_MN>()),
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<UnderlyingGmemStrideTypeC, ElementC, EpilogueTile_MN>()),
decltype(detail::sm90_get_smem_load_op_for_source<UnderlyingGmemStrideTypeC, ElementC>()),
CopyOpS2G,
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>()),
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<UnderlyingGmemStrideTypeD, ElementD, EpilogueTile_MN>()),
decltype(detail::sm90_get_smem_store_op_for_accumulator<UnderlyingGmemStrideTypeD, ElementD>()),
CopyAtomC
>;
};
@ -483,7 +484,7 @@ struct CollectiveBuilder<
FusionOperation,
cute::enable_if_t<cute::is_same_v<Schedule, TmaWarpSpecialized> ||
cute::is_same_v<Schedule, TmaWarpSpecializedCooperative> ||
cute::is_same_v<Schedule, PtrArrayTmaWarpSpecializedCooperative> >> {
detail::sm90_is_ptr_array_tma_v<Schedule>>> {
private:
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
fusion::get_element_aux_t<FusionOperation>, ElementD_>;

View File

@ -71,6 +71,62 @@ is_im2col() {
|| cute::is_same_v<Stride, cutlass::detail::TagToStrideC_t<cutlass::layout::TensorNDHWC>>;
}
template<class Schedule>
struct sm90_is_ptr_array_tma : cute::false_type {};
template<>
struct sm90_is_ptr_array_tma<PtrArrayTmaWarpSpecializedCooperative> : cute::true_type {};
template<>
struct sm90_is_ptr_array_tma<PtrArrayTmaWarpSpecializedPingpong> : cute::true_type {};
template<>
struct sm90_is_ptr_array_tma<PtrArrayTmaWarpSpecialized> : cute::true_type {};
template<class Schedule>
static constexpr bool sm90_is_ptr_array_tma_v = sm90_is_ptr_array_tma<Schedule>::value;
template<class Schedule>
struct sm90_is_ptr_array_tma_cooperative : cute::false_type {};
template<>
struct sm90_is_ptr_array_tma_cooperative<PtrArrayTmaWarpSpecializedCooperative> : cute::true_type {};
template<class Schedule>
static constexpr bool sm90_is_ptr_array_tma_cooperative_v = sm90_is_ptr_array_tma_cooperative<Schedule>::value;
template<class Schedule>
struct sm90_is_ptr_array_tma_pingpong : cute::false_type {};
template<>
struct sm90_is_ptr_array_tma_pingpong<PtrArrayTmaWarpSpecializedPingpong> : cute::true_type {};
template<class Schedule>
static constexpr bool sm90_is_ptr_array_tma_pingpong_v = sm90_is_ptr_array_tma_pingpong<Schedule>::value;
template<class DispatchPolicy>
struct sm90_is_ptr_array_tma_dispatch_policy : cute::false_type {};
template<
int StagesC,
int StagesD,
int FragmentSize,
bool ReuseSmemC,
bool DelayTmaStore,
int NumEpilogueWarpGroups
>
struct sm90_is_ptr_array_tma_dispatch_policy<
Sm90PtrArrayTmaWarpSpecialized<StagesC,
StagesD,
FragmentSize,
ReuseSmemC,
DelayTmaStore,
NumEpilogueWarpGroups>>
: cute::true_type {};
template<class DispatchPolicy>
static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy<DispatchPolicy>::value;
using cutlass::atomic_maximum;
template <class T>
@ -79,14 +135,11 @@ static constexpr int elements_per_access_v = cutlass::sizeof_bits<uint32_t>::val
template <class EpilogueSchedule>
static constexpr bool sm90_is_cooperative_v =
cute::is_base_of_v<cutlass::epilogue::TmaWarpSpecializedCooperative, EpilogueSchedule> ||
cute::is_base_of_v<cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative, EpilogueSchedule>;
template <class EpilogueSchedule>
static constexpr bool sm90_is_tma_ptr_array_v =
cute::is_base_of_v<cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative, EpilogueSchedule>;
sm90_is_ptr_array_tma_cooperative_v<EpilogueSchedule>;
template <class EpilogueSchedule>
static constexpr bool sm90_is_warp_specialized_v =
(!sm90_is_ptr_array_tma_cooperative_v<EpilogueSchedule> && sm90_is_ptr_array_tma_v<EpilogueSchedule>) ||
cute::is_base_of_v<cutlass::epilogue::TmaWarpSpecialized, EpilogueSchedule>;
template <class GmemLayoutTag>
@ -199,7 +252,11 @@ public:
}
CUTLASS_DEVICE auto
load_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, [[maybe_unused]] int32_t const sm_idx) const {
load_init(
[[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
[[maybe_unused]] int32_t sm_count,
[[maybe_unused]] int32_t sm_idx) {
return cute::make_tuple(nullptr);
}
@ -243,7 +300,7 @@ public:
[[maybe_unused]] TensorStorage& shared_tensors,
[[maybe_unused]] TensorMapC const& load_tensormap,
[[maybe_unused]] int subtile_idx=-1,
[[maybe_unused]] bool return_prior_state = false)
[[maybe_unused]] bool wait = false)
{
return load_pipe_producer_state;
}
@ -257,8 +314,12 @@ public:
}
CUTLASS_DEVICE auto
store_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count,
[[maybe_unused]] int32_t const sm_idx) const {
store_init(
[[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
[[maybe_unused]] int32_t sm_count,
[[maybe_unused]] int32_t sm_idx,
[[maybe_unused]] int32_t warp_group_idx) {
return cute::make_tuple(nullptr);
}
@ -369,22 +430,25 @@ public:
// Dummy methods to perform different parts of TMA/Tensormap modifications
template <bool IsLoad>
template <bool IsLoad, class ProblemShapeMNKL>
CUTLASS_DEVICE
void
tensormaps_perform_update(
[[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
[[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] cute::TmaDescriptor const* tensormap,
[[maybe_unused]] int32_t next_batch) { }
[[maybe_unused]] ProblemShapeMNKL problem_shape,
[[maybe_unused]] int32_t next_batch,
[[maybe_unused]] int32_t warp_group_idx) { }
template <bool IsLoad>
CUTLASS_DEVICE
void
tensormaps_cp_fence_release(
[[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
[[maybe_unused]] cute::TmaDescriptor const* tensormap,
[[maybe_unused]] uint32_t lane_predicate) { }
[[maybe_unused]] uint32_t lane_predicate,
[[maybe_unused]] int32_t warp_group_idx) { }
template <bool IsLoad>
CUTLASS_DEVICE

View File

@ -44,9 +44,10 @@
#include "cutlass/detail/collective.hpp"
#include "cutlass/detail/layout.hpp"
#include "cutlass/trace.h"
#include "cutlass/cuda_host_adapter.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cuda_host_adapter.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -62,6 +63,7 @@ template <
int FragmentSize_,
bool ReuseSmemC_,
bool DelayTmaStore_,
int NumEpilogueWarpGroups_,
class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K)
class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N)
class ElementC_,
@ -78,7 +80,13 @@ template <
class CopyAtomC_
>
class CollectiveEpilogue<
Sm90PtrArrayTmaWarpSpecialized<StagesC_,StagesD_,FragmentSize_,ReuseSmemC_,DelayTmaStore_>,
Sm90PtrArrayTmaWarpSpecialized<StagesC_,
StagesD_,
FragmentSize_,
ReuseSmemC_,
DelayTmaStore_,
NumEpilogueWarpGroups_
>,
CtaTileMNK_,
EpilogueTile_,
ElementC_,
@ -98,7 +106,13 @@ public:
//
// Type Aliases
//
using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized<StagesC_,StagesD_,FragmentSize_,ReuseSmemC_,DelayTmaStore_>;
using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized<StagesC_,
StagesD_,
FragmentSize_,
ReuseSmemC_,
DelayTmaStore_,
NumEpilogueWarpGroups_
>;
using CtaTileMNK = CtaTileMNK_;
using EpilogueTile = EpilogueTile_;
using FusionCallbacks = FusionCallbacks_;
@ -201,6 +215,8 @@ public:
(size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof_bits<SmemElementC>::value)) / 8;
constexpr static bool RequiresTransactionBytes = true;
constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_;
// TMA pipeline for storing D
using StorePipeline = cute::conditional_t<ReuseSmemC,
cutlass::PipelineTmaStore<StagesC, StagesD-1>,
@ -219,7 +235,7 @@ public:
struct TensorMapStorage : cute::aligned_struct<128> {
cute::TmaDescriptor smem_tensormap_C;
cute::TmaDescriptor smem_tensormap_D;
cute::array<cute::TmaDescriptor, NumEpilogueWarpGroups> smem_tensormap_D;
} tensormaps;
using PipelineStorage = typename LoadPipeline::SharedStorage;
@ -229,6 +245,8 @@ public:
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideC, StrideC>;
// Host side epilogue arguments
struct Arguments {
typename FusionCallbacks::Arguments thread{};
@ -261,7 +279,9 @@ public:
TMA_D tma_store_d;
cute::TmaDescriptor* tensormaps;
ElementC const** ptr_C;
StrideC dC;
ElementD** ptr_D;
StrideD dD;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
};
@ -275,36 +295,57 @@ public:
ProblemShape const& problem_shape,
Arguments const& args,
[[maybe_unused]] void* workspace) {
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1);
auto [M, N, K, mock_L] = problem_shape_MNKL;
// Manage batches/groups through pointers to input matrices
mock_L = 1;
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
// These will be replaced with correct values before the initial tma load.
auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
auto init_M = get<0>(init_shape);
auto init_N = get<1>(init_shape);
auto init_L = get<3>(init_shape);
static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D");
InternalStrideC stride_c;
InternalStrideD stride_d;
if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
stride_c = InternalStrideC{};
stride_d = InternalStrideD{};
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1);
init_M = get<0>(problem_shape_MNKL);
init_N = get<1>(problem_shape_MNKL);
init_L = get<3>(problem_shape_MNKL);
stride_c = args.dC;
stride_d = args.dD;
}
uint32_t transaction_bytes = TmaTransactionBytes;
typename Params::TMA_C tma_load_c = {};
if constexpr (is_source_supported) {
ElementC const* ptr_C_first_batch = reinterpret_cast<ElementC const*>(args.ptr_C);
Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dC, _0{})));
tma_load_c = make_tma_copy_C_sm90(
Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{})));
tma_load_c = make_tma_copy(
CopyOpG2S{},
tensor_c,
take<0,2>(SmemLayoutC{}),
EpilogueTile{});
EpilogueTile{},
_1{});
}
typename Params::TMA_D tma_store_d;
if constexpr (is_destination_supported) {
ElementD const* ptr_D_first_batch = reinterpret_cast<ElementD const*>(args.ptr_D);
Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dD, _0{})));
tma_store_d = make_tma_copy_C_sm90(
Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{})));
tma_store_d = make_tma_copy(
CopyOpS2G{},
tensor_d,
take<0,2>(SmemLayoutD{}),
EpilogueTile{});
EpilogueTile{},
_1{});
}
auto fusion_workspace = static_cast<char*>(workspace);
@ -318,7 +359,9 @@ public:
tma_store_d,
tma_descriptor_workspace,
args.ptr_C,
args.dC,
args.ptr_D,
args.dD,
transaction_bytes,
};
}
@ -326,10 +369,11 @@ public:
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
constexpr uint32_t NumInputTensors = cute::is_void_v<ElementC> ? 1 : 2;
constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v<ElementC> ? 0 : 1);
auto descriptors_shape = cute::make_shape(sm_count, Int<NumInputTensors>{});
constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
// Allocate gmem space for tensormaps per SM: one D tensormap copy per epilogue warp group, followed by the C tensormap copy
return (NumInputTensors * SizeOfCuTensorMap * sm_count) + FusionCallbacks::get_workspace_size(problem_shape, args.thread);
return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread);
}
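For example, with the pingpong or cooperative ptr-array epilogue schedule (NumEpilogueWarpGroups = 2) and a non-void ElementC, NumInputTensors = 2 + 1 = 3, so the tensormap workspace is 3 * sizeof(cute::TmaDescriptor) * sm_count bytes on top of the FusionCallbacks workspace (illustrative arithmetic, not part of the diff).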
template <class ProblemShape>
@ -342,30 +386,40 @@ public:
template <class ProblemShape>
static bool
can_implement(
ProblemShape const& problem_shape,
ProblemShape problem_shape,
[[maybe_unused]] Arguments const& args) {
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1);
auto [M,N,K,L] = problem_shape_MNKL;
bool implementable = true;
if constexpr (is_destination_supported) {
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), InternalStrideD{});
}
bool fusion_implementable = true;
if constexpr (not cute::is_void_v<ElementC>) {
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(cute::make_shape(M,N,L), InternalStrideC{});
if (problem_shape.is_host_problem_shape_available()) {
for (int i = 0; i < problem_shape.groups(); ++i) {
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
if constexpr (is_destination_supported) {
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), InternalStrideD{});
}
if constexpr (not cute::is_void_v<ElementC>) {
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(cute::make_shape(M,N,L), InternalStrideC{});
}
fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread);
}
}
else {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n");
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread);
if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
}
@ -414,10 +468,14 @@ public:
}
CUTLASS_DEVICE auto
load_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const {
load_init(
Params const& params,
TensorMapStorage& shared_tensormaps,
int32_t sm_count,
int32_t sm_idx) {
// Initialize tma for loading
constexpr bool IsLoad = true;
auto load_tensormaps = tensormaps_init<IsLoad>(params, sm_count, sm_idx);
auto load_tensormaps = tensormaps_init<IsLoad>(params, shared_tensormaps, sm_count, sm_idx, 0);
return load_tensormaps;
}
@ -426,7 +484,8 @@ public:
class TileShapeMNK,
class TileCoordMNKL,
class TiledMma,
class TensorMapC
class TensorMapC,
__CUTE_REQUIRES(std::is_pointer_v<TensorMapC>)
>
CUTLASS_DEVICE auto
load(
@ -440,7 +499,7 @@ public:
TensorStorage& shared_tensors,
TensorMapC const& load_tensormap,
int subtile_idx=-1,
bool return_prior_state = false) {
bool wait_until_load_finishes = false) {
using namespace cute;
// Indexing variables
@ -478,17 +537,21 @@ public:
auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args);
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
LoadPipelineState last_load_producer_state = load_pipe_producer_state;
// Predication for TMA load (one thread issues TMA load)
bool issue_tma_load = cute::elect_one_sync();
// Acquire the lock for the first stage
uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
load_pipeline.producer_acquire(load_pipe_producer_state);
uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
// Pre-loop fusion callback entry point
pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load);
auto prior_state = load_pipe_producer_state;
LoadPipelineState prior_state = load_pipe_producer_state;
bool did_load = false;
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) {
@ -506,15 +569,18 @@ public:
pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load);
// Execute the TMA load for C if needed
if (issue_tma_load && is_C_load_needed) {
copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask),
bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index()));
load_pipeline.producer_expect_transaction(load_pipe_producer_state);
if (is_C_load_needed) {
if (issue_tma_load) {
copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask),
bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index()));
load_pipeline.producer_expect_transaction(load_pipe_producer_state);
}
last_load_producer_state = load_pipe_producer_state;
did_load = true;
}
// Commit TMA loads for this stage and release the lock
load_pipeline.producer_commit(load_pipe_producer_state);
prior_state = load_pipe_producer_state;
++load_pipe_producer_state;
}
}
@ -522,17 +588,24 @@ public:
// Post-loop fusion callback entry point
pld_callbacks.end();
if (not return_prior_state) {
return load_pipe_producer_state;
} else {
return prior_state;
if (wait_until_load_finishes && did_load) {
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
{last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()};
load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
}
return load_pipe_producer_state;
}
CUTLASS_DEVICE auto
load_tail(
LoadPipeline load_pipeline,
LoadPipelineState load_pipe_producer_state) {
if (!fusion_callbacks.is_producer_load_needed()) {
return load_pipe_producer_state;
}
bool issue_tma_load = cute::elect_one_sync();
if (issue_tma_load) {
load_pipeline.producer_tail(load_pipe_producer_state);
@ -564,6 +637,7 @@ public:
TensorStorage& shared_tensors,
TensorMapD const& store_tensormap,
int subtile_idx=-1) {
using namespace cute;
using ElementAccumulator = typename AccEngine::value_type;
using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::ElementCompute;
@ -869,11 +943,22 @@ public:
}
CUTLASS_DEVICE auto
store_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const {
// Initialize tma
constexpr bool IsLoad = false;
auto store_tensormaps = tensormaps_init<IsLoad>(params, sm_count, sm_idx);
return store_tensormaps;
store_init(
Params const& params,
TensorMapStorage& shared_tensormaps,
int32_t sm_count,
int32_t sm_idx,
int32_t warp_group_idx) {
int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup;
// Since only one warp issues TMA store, we only need that one warp to initialize tensormaps
if (warp_idx_in_warp_group == 0) {
// Initialize tma
constexpr bool IsLoad = false;
auto store_tensormaps = tensormaps_init<IsLoad>(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx);
return store_tensormaps;
}
TmaDescriptor* null_tma_desc = nullptr;
return cute::make_tuple(null_tma_desc);
}
//
@ -882,53 +967,45 @@ public:
template <bool IsLoad>
CUTLASS_DEVICE auto
tensormaps_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const {
cute::TmaDescriptor* tma_desc = nullptr;
cute::TmaDescriptor* gmem_tensormap = params.tensormaps;
tensormaps_init(
Params const& params,
TensorMapStorage& shared_tensormaps,
int32_t sm_count,
int32_t sm_idx,
int32_t warp_group_idx) {
constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v<ElementC> ? 0 : 1);
Layout desc_layout = make_layout(make_shape(sm_count, Int<NumInputTensors>{}));
Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors)
if constexpr (IsLoad) {
if (not cute::is_void_v<ElementC>) {
tma_desc = &gmem_tensormap[sm_idx];
constexpr int C_tensormap_index = NumEpilogueWarpGroups;
Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{});
if (cute::elect_one_sync()) {
// Bringing tensormaps from params to gmem for modification later
Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor gC_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{});
copy(recast<uint128_t>(pC_tensormap), recast<uint128_t>(gC_tensormap));
// Bringing tensormaps from params to smem for modification later
copy(recast<uint128_t>(pC_tensormap), recast<uint128_t>(sC_tensormap));
}
__syncwarp();
return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index));
}
} else {
int const offset_Ddesc = cute::is_void_v<ElementC> ? 0 : sm_count;
tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc];
TmaDescriptor* null_tma_desc = nullptr;
return cute::make_tuple(null_tma_desc);
}
else {
Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{});
if (cute::elect_one_sync()) {
// Bringing tensormaps from params to gmem for modification later
Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor gD_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{});
copy(recast<uint128_t>(pD_tensormap), recast<uint128_t>(gD_tensormap));
// Bringing tensormaps from params to smem for modification later
copy(recast<uint128_t>(pD_tensormap), recast<uint128_t>(sD_tensormap));
}
__syncwarp();
return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx));
}
return cute::make_tuple(tma_desc);
}
// Bringing tensormaps to smem (to be done by single thread)
template <bool IsLoad>
CUTLASS_DEVICE
void
tensormaps_fetch_to_smem(
TensorMapStorage& shared_tensormap,
cute::TmaDescriptor const* tensormap) const {
if constexpr (IsLoad) {
if (not cute::is_void_v<ElementC>) {
Tensor gC_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{});
Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{});
copy(recast<uint128_t>(gC_tensormap), recast<uint128_t>(sC_tensormap));
}
} else {
Tensor gD_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{});
Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{});
copy(recast<uint128_t>(gD_tensormap), recast<uint128_t>(sD_tensormap));
}
cp_async_fence();
cp_async_wait<0>();
}
// Replace address for the global tensor (to be done by single thread)
@ -936,35 +1013,99 @@ public:
CUTLASS_DEVICE
void
tensormaps_replace_global_address(
TensorMapStorage& shared_tensormap,
TensorMapStorage& shared_tensormaps,
Params const& params,
int32_t next_batch) {
int32_t next_batch,
int32_t warp_group_idx) {
// Replacing global_address for the next batch
if constexpr (IsLoad) {
if (not cute::is_void_v<ElementC>) {
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C,
if constexpr (is_source_supported) {
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C,
params.ptr_C[next_batch]);
}
} else {
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D,
}
else if constexpr (is_destination_supported) {
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx],
params.ptr_D[next_batch]);
}
}
template <bool IsLoad>
// Replace dims and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
template <bool IsLoad, class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_replace_global_tensor_properties(
TensorMapStorage& shared_tensormaps,
Params const& params,
int32_t next_group,
ProblemShape_MNKL problem_shape_mnkl,
int32_t warp_group_idx) {
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
// Only consider dimensions and strides that we need to recalculate and replace for each group
constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding K; C/D tensors are (M,N,L)
static_assert(TensorRank == Int<3>{},
"Descriptor modification for global dims & strides expects rank as 3.");
cute::array<uint32_t, TensorRank> prob_shape = {1,1,1};
cute::array<uint64_t, TensorRank> prob_stride = {0,0,0};
if constexpr (IsLoad) {
if constexpr (is_source_supported) {
ElementC const* ptr_C = nullptr;
Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group]));
cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c,
prob_shape, prob_stride);
// Convert strides to byte strides
for (uint64_t& stride : prob_stride) {
stride = (stride * sizeof_bits_v<ElementC>) / 8;
}
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C,
prob_shape,
prob_stride);
}
}
else if constexpr (is_destination_supported) {
ElementD const* ptr_D = nullptr;
// tensor_d is a dummy gmem tensor used only to recompute the shape/stride for tma_store_d
Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group]));
cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d,
prob_shape, prob_stride);
// Convert strides to byte strides
for (uint64_t& stride : prob_stride) {
stride = (stride * sizeof_bits_v<ElementD>) / 8;
}
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx],
prob_shape,
prob_stride);
}
}
template <bool IsLoad, class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_perform_update(
TensorMapStorage& shared_tensormap,
TensorMapStorage& shared_tensormaps,
Params const& params,
cute::TmaDescriptor const* tensormap,
int32_t next_batch) {
ProblemShape_MNKL problem_shape_mnkl,
int32_t next_batch,
int32_t warp_group_idx) {
if (cute::elect_one_sync()) {
// Bringing tensormaps to smem
tensormaps_fetch_to_smem<IsLoad>(shared_tensormap, tensormap);
// Replacing global_address for the next batch
tensormaps_replace_global_address<IsLoad>(shared_tensormap, params, next_batch);
tensormaps_replace_global_address<IsLoad>(shared_tensormaps, params, next_batch, warp_group_idx);
if constexpr (IsGroupedGemmKernel) {
// Replacing global dims and strides for the next batch
tensormaps_replace_global_tensor_properties<IsLoad>(
shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx);
}
}
}
@ -972,16 +1113,18 @@ public:
CUTLASS_DEVICE
void
tensormaps_cp_fence_release(
TensorMapStorage& shared_tensormap,
TensorMapStorage& shared_tensormaps,
cute::TmaDescriptor const* tensormap,
[[maybe_unused]] uint32_t lane_predicate) {
[[maybe_unused]] uint32_t lane_predicate,
int32_t warp_group_idx = 0) {
// Entire warp must do this (i.e., it's aligned)
if constexpr (IsLoad) {
if (not cute::is_void_v<ElementC>) {
tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C);
if constexpr (is_source_supported) {
tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C);
}
} else {
tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D);
}
else if constexpr (is_destination_supported) {
tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]);
}
}
@ -990,10 +1133,11 @@ public:
void
tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) {
if constexpr (IsLoad) {
if (not cute::is_void_v<ElementC>) {
if constexpr (not cute::is_void_v<ElementC>) {
cute::tma_descriptor_fence_acquire(tensormap);
}
} else {
}
else {
cute::tma_descriptor_fence_acquire(tensormap);
}
}

View File

@ -51,7 +51,21 @@ struct PtrArrayNoSmemWarpSpecialized {};
struct PtrArrayPlanarComplexNoSmemWarpSpecialized {};
struct TmaWarpSpecialized {};
struct TmaWarpSpecializedCooperative {};
struct PtrArrayTmaWarpSpecializedCooperative {};
struct PtrArrayTmaWarpSpecializedCooperative {
static constexpr int NumEpilogueWarpGroups = 2;
};
// Standard warp specialized epilogue
struct PtrArrayTmaWarpSpecialized {
static constexpr int NumEpilogueWarpGroups = 1;
};
// Pingpong kernel epilogue
struct PtrArrayTmaWarpSpecializedPingpong {
static constexpr int NumEpilogueWarpGroups = 2;
};
// DEPRECATED schedules, will be removed in next release
struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {};
struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {};
@ -151,7 +165,8 @@ template<
int StagesD_,
int FragmentSize_,
bool ReuseSmemC_,
bool DelayTmaStore_
bool DelayTmaStore_,
int NumEpilogueWarpGroups_
>
struct Sm90PtrArrayTmaWarpSpecialized {
constexpr static int StagesC = StagesC_;
@ -159,6 +174,7 @@ struct Sm90PtrArrayTmaWarpSpecialized {
constexpr static int FragmentSize = FragmentSize_;
constexpr static bool ReuseSmemC = ReuseSmemC_;
constexpr static bool DelayTmaStore = DelayTmaStore_;
constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_;
};
// DEPRECATED policies, will be removed in next release
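The epilogue schedule tags now expose NumEpilogueWarpGroups, and the ptr-array dispatch policy carries it as a template parameter; a few illustrative checks (not part of the commit):

static_assert(cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative::NumEpilogueWarpGroups == 2);
static_assert(cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong::NumEpilogueWarpGroups == 2);
static_assert(cutlass::epilogue::PtrArrayTmaWarpSpecialized::NumEpilogueWarpGroups == 1);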

View File

@ -32,6 +32,7 @@
#pragma once
#include <cutlass/numeric_conversion.h>
#include <cutlass/layout/matrix.h>
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -179,6 +179,89 @@ struct FusionCallbacks<
/////////////////////////////////////////////////////////////////////////////////////////////////
// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch
template<
class ElementOutput,
class ElementCompute,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinearCombinationPtrArray =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int>>, // beta
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int>>, // alpha
Sm90AccFetch // acc
>
>;
template <
int StagesC,
int StagesD,
int FragmentSize,
bool ReuseSmemC,
bool DelayTmaStore,
int NumEpilogueWarpGroups,
class ElementOutput,
class ElementCompute,
class ElementSource,
class ElementScalar,
FloatRoundStyle RoundStyle,
class CtaTileShapeMNK,
class EpilogueTile
>
struct FusionCallbacks<
epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC,
StagesD,
FragmentSize,
ReuseSmemC,
DelayTmaStore,
NumEpilogueWarpGroups
>,
fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90LinearCombinationPtrArray<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> {
using Impl = Sm90LinearCombinationPtrArray<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
struct Arguments {
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
ElementScalar const* alpha_ptr = nullptr;
ElementScalar const* beta_ptr = nullptr;
ElementScalar const* const* alpha_ptr_array = nullptr;
ElementScalar const* const* beta_ptr_array = nullptr;
using StrideAlpha = Stride<_0,_0,int>;
using StrideBeta = Stride<_0,_0,int>;
StrideAlpha dAlpha = {_0{}, _0{}, 0};
StrideBeta dBeta = {_0{}, _0{}, 0};
operator typename Impl::Arguments() const {
return
{ // ternary op : beta * C + (alpha * acc)
{{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta
{}, // leaf args : C
{ // binary op : alpha * acc
{{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha
{}, // leaf args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
}; // end ternary op
}
};
// Ctor inheritance
using Impl::Impl;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// D = activation(alpha * acc + beta * C)
template<
template <class> class ActivationFn,

View File

@ -37,6 +37,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cute/tensor.hpp"
#include "sm90_visitor_tma_warpspecialized.hpp"
@ -514,7 +515,8 @@ private:
if (params_ptr->scalar_ptrs[0] != nullptr) {
scalar = params_ptr->scalar_ptrs[0][l_offset];
} else {
}
else {
// batch stride is ignored for nullptr fallback
scalar = params_ptr->scalars[0];
}
@ -541,6 +543,169 @@ private:
}
};
// Scalar broadcast
// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors
template<
class Element,
class StrideMNL = Stride<_0,_0,_0>,
int BroadcastCount = 1,
template <class> class ReductionFn = multiplies
>
struct Sm90ScalarBroadcastPtrArray {
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{});
struct SharedStorage { };
struct Arguments {
Element scalars[BroadcastCount] = {};
Element const* scalar_ptrs[BroadcastCount] = {};
Element const* const* scalar_ptr_arrays[BroadcastCount] = {};
StrideMNL dScalar[BroadcastCount] = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter *cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
// producer load is needed if Element is not void and we have multiple scalars
return !cute::is_void_v<Element> and size<2>(params_ptr->dScalar[0]) != 0;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
// This must be called after update_scalar is called
CUTLASS_DEVICE bool
is_zero() const {
return scalar == Element(0);
}
CUTLASS_HOST_DEVICE
Sm90ScalarBroadcastPtrArray() { }
CUTLASS_HOST_DEVICE
Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) {
// Get the scalar for non-batched broadcast
if (size<2>(params_ptr->dScalar[0]) == 0) {
update_scalar();
}
}
Element scalar;
Params const* params_ptr;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
// Get the scalar for batched broadcast
if (get<2>(params_ptr->dScalar[0]) != 0) {
auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl;
update_scalar(l_coord);
}
return EmptyProducerLoadCallbacks{};
}
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(Element scalar)
: scalar(scalar) {}
Element scalar;
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_scalar;
frg_scalar.fill(scalar);
return frg_scalar;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
// Get the scalar for batched broadcast
if (get<2>(params_ptr->dScalar[0]) != 0) {
auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl;
update_scalar(l_coord);
}
return ConsumerStoreCallbacks(scalar);
}
private:
CUTLASS_DEVICE void
update_scalar(int l_coord = 0) {
int l_offset = l_coord * size<2>(params_ptr->dScalar[0]);
if (params_ptr->scalar_ptr_arrays[0] != nullptr) {
scalar = *(params_ptr->scalar_ptr_arrays[0][l_offset]);
}
else if (params_ptr->scalar_ptrs[0] != nullptr) {
scalar = params_ptr->scalar_ptrs[0][l_offset];
}
else {
// batch stride is ignored for nullptr fallback
scalar = params_ptr->scalars[0];
}
// Do reduction over multiple broadcasts if necessary
ReductionFn<Element> reduction_fn;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < BroadcastCount; ++i) {
if (params_ptr->scalar_ptr_arrays[i] != nullptr) {
int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][rest_l_offset]));
}
if (params_ptr->scalar_ptrs[i] != nullptr) {
int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]);
}
else {
// batch stride is ignored for nullptr fallback
scalar = reduction_fn(scalar, params_ptr->scalars[i]);
}
}
}
};
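// Illustrative sketch (assumed std types and a multiplies reduction as stand-ins, not the
// CUTLASS implementation): update_scalar() above applies, for each broadcast i, the precedence
// per-batch array of scalar pointers, then batch-strided scalar pointer, then immediate value,
// folding broadcasts 1..N-1 in through the reduction functor.
#include <cstddef>
#include <functional>
#include <vector>

float resolve_scalar(std::vector<float const* const*> const& scalar_ptr_arrays,
                     std::vector<float const*> const&        scalar_ptrs,
                     std::vector<float> const&               scalars,
                     std::vector<int> const&                 l_offsets) {
  auto pick = [&](std::size_t i) -> float {
    if (scalar_ptr_arrays[i] != nullptr) { return *scalar_ptr_arrays[i][l_offsets[i]]; }
    if (scalar_ptrs[i] != nullptr)       { return scalar_ptrs[i][l_offsets[i]]; }
    return scalars[i];                   // batch stride is ignored for the immediate fallback
  };
  float scalar = pick(0);
  std::multiplies<float> reduction_fn;   // stand-in for ReductionFn<Element>
  for (std::size_t i = 1; i < scalars.size(); ++i) {
    scalar = reduction_fn(scalar, pick(i));
  }
  return scalar;
}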
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {

View File

@ -31,6 +31,10 @@
#pragma once
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/pipeline/sm90_pipeline.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
// SM90 Collective Builders should only be used with CUDA 12.0 and newer
#if (__CUDACC_VER_MAJOR__ >= 12)
@ -177,10 +181,12 @@ struct CollectiveBuilder<
StageCountType,
KernelScheduleType,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &&
(cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecialized,
KernelTmaWarpSpecializedCooperative,
KernelTmaWarpSpecializedPingpong,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
static_assert(is_static<TileShape_MNK>::value);
@ -191,10 +197,12 @@ struct CollectiveBuilder<
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
"KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now.");
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@ -203,8 +211,10 @@ struct CollectiveBuilder<
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> || IsArrayOfPointersGemm,
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
@ -218,7 +228,10 @@ struct CollectiveBuilder<
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
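// Illustrative arithmetic for the carveout above (assumed numbers, not read from the builder):
//   sizeof(cute::TmaDescriptor)         = 128 B
//   TensorMapStorage                    = 2 * 128 B = 256 B   (one descriptor each for A and B)
//   SMEM available for mainloop stages  = sm90_smem_capacity_bytes - 256 B
// With an assumed 232448 B capacity and ~32 KiB of A+B tile data per stage, that is still
// floor((232448 - 256) / 32768) = 7 stages; the carveout can cost at most one stage.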
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
@ -505,10 +518,12 @@ struct CollectiveBuilder<
StageCountType,
KernelScheduleType,
cute::enable_if_t<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>
cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedFP8FastAccum,
KernelTmaWarpSpecializedPingpongFP8FastAccum,
KernelTmaWarpSpecializedCooperativeFP8FastAccum,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum>>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
@ -526,10 +541,15 @@ struct CollectiveBuilder<
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
IsArrayOfPointersGemm,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum>;
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8FastAccum,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative, Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
@ -542,7 +562,11 @@ struct CollectiveBuilder<
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int PipelineStages = detail::compute_stage_count_or_override<Sm90ReducedSmemCapacityBytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,

View File

@ -623,56 +623,42 @@ struct CollectiveMma<
//
CUTLASS_DEVICE auto
tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const {
tensormaps_init(
Params const& mainloop_params,
TensorMapStorage& shared_tensormaps,
int32_t sm_count,
int32_t sm_idx) {
cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);
cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];
if (cute::elect_one_sync()) {
// Bringing tensormaps from params to gmem for modification later
// Bringing tensormaps from params to smem for modification later
Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{});
Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{});
Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});
copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(gA_tensormap));
copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(gB_tensormap));
copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));
}
__syncwarp();
return cute::make_tuple(tma_desc_a, tma_desc_b);
}
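// The revised tensormaps_init() above stages the host-built descriptors directly into shared
// memory from a single elected thread and then converges the warp. A self-contained CUDA sketch
// of that pattern (stand-in descriptor type and function names, not the CUTLASS code):
#include <cstdint>
#include <cuda_runtime.h>

struct alignas(128) TmaDescriptorBytes { uint8_t bytes[128]; };  // stand-in for cute::TmaDescriptor

__device__ void fetch_descriptor_to_smem(TmaDescriptorBytes const* params_desc,
                                         TmaDescriptorBytes*       smem_desc,
                                         bool                      is_elected) {
  if (is_elected) {
    // Copy the 128-byte descriptor in 16-byte chunks, like the uint128_t recast-copy above
    auto const* src = reinterpret_cast<int4 const*>(params_desc);
    auto*       dst = reinterpret_cast<int4*>(smem_desc);
    for (int i = 0; i < 8; ++i) { dst[i] = src[i]; }
  }
  __syncwarp();  // converge before any thread updates or fence-releases the shared copy
}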
// Bringing tensormaps to smem (to be done by single thread)
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_fetch_to_smem(
TensorMapStorage& shared_tensormap,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) const {
Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{});
Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{});
Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{});
Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{});
copy(recast<uint128_t>(gA_tensormap), recast<uint128_t>(sA_tensormap));
copy(recast<uint128_t>(gB_tensormap), recast<uint128_t>(sB_tensormap));
cp_async_fence();
cp_async_wait<0>();
}
// Replace address for the global tensor (to be done by single thread)
CUTLASS_DEVICE
void
tensormaps_replace_global_address(
TensorMapStorage& shared_tensormap,
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
int32_t next_batch) {
// Replacing global_address for the next batch
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A,
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A,
mainloop_params.ptr_A[next_batch]);
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B,
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B,
mainloop_params.ptr_B[next_batch]);
}
@ -681,7 +667,7 @@ struct CollectiveMma<
CUTLASS_DEVICE
void
tensormaps_replace_global_tensor_properties(
TensorMapStorage& shared_tensormap,
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
int32_t next_group,
ProblemShape_MNKL problem_shape_mnkl) {
@ -716,10 +702,10 @@ struct CollectiveMma<
stride = (stride * sizeof_bits_v<InternalElementB>) / 8;
}
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A,
prob_shape_A,
prob_stride_A);
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_B,
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B,
prob_shape_B,
prob_stride_B);
}
@ -728,21 +714,19 @@ struct CollectiveMma<
CUTLASS_DEVICE
void
tensormaps_perform_update(
TensorMapStorage& shared_tensormap,
TensorMapStorage& shared_tensormaps,
Params const& mainloop_params,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
ProblemShape_MNKL problem_shape_mnkl,
int32_t next_batch) {
if (cute::elect_one_sync()) {
// Bringing tensormaps to smem
tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps);
// Replacing global_address for the next batch
tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch);
tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);
if constexpr (IsGroupedGemmKernel) {
// Replacing global dims and strides for the next batch
tensormaps_replace_global_tensor_properties(shared_tensormap,
tensormaps_replace_global_tensor_properties(shared_tensormaps,
mainloop_params, next_batch, problem_shape_mnkl);
}
}
@ -752,11 +736,11 @@ struct CollectiveMma<
CUTLASS_DEVICE
void
tensormaps_cp_fence_release (
TensorMapStorage& shared_tensormap,
TensorMapStorage& shared_tensormaps,
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
// Entire warp must do this (i.e. it's aligned)
tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A);
tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B);
tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A);
tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B);
}
// The entire warp must call this function collectively (that is, the instructions are aligned)

View File

@ -98,6 +98,7 @@ struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpong { };
//////////////////////////////////////////////////////////////////////////////
@ -111,6 +112,7 @@ struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum : KernelPtrArrayTmaWarpSpecializedPingpong { };
// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -286,8 +288,9 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
static_assert(
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule>,
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>,
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
};
//////////////////////////////////////////////////////////////////////////////

View File

@ -61,5 +61,6 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp"
////////////////////////////////////////////////////////////////////////////////

View File

@ -46,6 +46,8 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
@ -73,6 +75,9 @@ public:
using ProblemShape = ProblemShape_;
static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
static_assert(cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>);
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
@ -119,8 +124,9 @@ public:
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t NumMmaThreads = CUTE_STATIC_V(size(TiledMma{}));
static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
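// Worked numbers for the refactor above (cooperative schedule, where size(TiledMma{}) == 256):
//   NumMmaThreads      = 256
//   NumMmaWarpGroups   = 256 / NumThreadsPerWarpGroup = 256 / 128 = 2
//   MaxThreadsPerBlock = 256 + NumLoadWarpGroups * 128 = 256 + 128 = 384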
/// Register requirement for Load and Math WGs
@ -215,11 +221,11 @@ public:
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count);
workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* mainloop_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count);
workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
// Precompute the number of epilogue subtiles and pass it to the tile scheduler; it is used
@ -275,9 +281,6 @@ public:
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
@ -286,6 +289,9 @@ public:
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
@ -363,6 +369,12 @@ public:
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
static_assert(NumMmaWarpGroups == 2, "Cooperative kernels currently only support NumMmaWarpGroups == 2");
if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v<typename CollectiveEpilogue::DispatchPolicy>) {
static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups,
"Tiled MmA does not match expected warp groups performing the epilogue");
}
static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -391,7 +403,8 @@ public:
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
int mma_thread_idx = thread_idx % size(TiledMma{});
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto warp_group_idx = canonical_warp_group_idx();
auto warp_group_role = WarpGroupRole(warp_group_idx);
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
@ -466,7 +479,9 @@ public:
// Get the appropriate blocks for this thread block -- potential for thread block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
TileScheduler scheduler{params.scheduler};
@ -484,7 +499,7 @@ public:
}
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
@ -510,7 +525,7 @@ public:
int32_t const sm_count = params.hw_info.sm_count;
// Fetch a copy of tensormaps for the CTA
auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, sm_count, sm_idx);
auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx);
// Update tensormap for the initial batch for the CTA
if (work_tile_info.is_valid()) {
@ -578,7 +593,7 @@ public:
if (work_tile_info.is_valid() && did_batch_change) {
curr_batch = next_batch;
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), Int<1>{});
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
}
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
// Since this state is waiting for loads to finish, it must start in the inverted phase.
@ -610,7 +625,7 @@ public:
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
int32_t const sm_count = params.hw_info.sm_count;
auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, sm_count, sm_idx));
auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx));
bool did_batch_change = true;
constexpr bool IsEpiLoad = true;
@ -620,23 +635,26 @@ public:
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_load_tensormap,
work_tile_info.L_idx
problem_shape_MNKL,
work_tile_info.L_idx,
0
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate);
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
}
load_order_barrier.wait();
while (work_tile_info.is_valid()) {
int32_t curr_batch = work_tile_info.L_idx;
bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler);
// Get next work tile
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
if (compute_epilogue) {
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
@ -649,6 +667,8 @@ public:
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
}
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline,
epi_load_pipe_producer_state,
@ -660,36 +680,34 @@ public:
shared_storage.tensors.epilogue,
epi_load_tensormap,
work_tile_info.reduction_subtile_idx(),
true // return state prior to last advance
wait
);
}
// Get next work tile
work_tile_info = scheduler.fetch_next_work(work_tile_info);
work_tile_info = next_work_tile_info;
did_batch_change = curr_batch != work_tile_info.L_idx;
if (work_tile_info.is_valid() && did_batch_change) {
// Wait for TMA load to finish before updating
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
{epi_load_pipe_producer_state.index(), !epi_load_pipe_producer_state.phase(), epi_load_pipe_producer_state.count()};
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
epi_load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
// tensormap update
{
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_load_tensormap,
problem_shape_MNKL,
work_tile_info.L_idx,
0
);
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_load_tensormap,
work_tile_info.L_idx
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate);
}
if(compute_epilogue) {
epi_load_pipe_producer_state.advance(1);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
}
}
} // Scheduler work fetch loop
@ -702,32 +720,43 @@ public:
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
// Index of warp group within consumer warp groups
int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
int32_t const sm_count = params.hw_info.sm_count;
// Whether we may need to issue tail arrives for TMA stores, in case the epilogue load is waiting on them
bool do_store_tail = false;
// Get a copy of tensormaps
auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, sm_count, sm_idx));
auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx));
bool did_batch_change = true;
constexpr bool IsEpiLoad = false;
if (work_tile_info.is_valid()) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_store_tensormap,
work_tile_info.L_idx
);
if (warp_idx_in_warp_group == 0) {
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate);
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_store_tensormap,
problem_shape_MNKL,
work_tile_info.L_idx,
consumer_warp_group_idx
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}
while (work_tile_info.is_valid()) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
int32_t curr_batch = work_tile_info.L_idx;
@ -743,6 +772,10 @@ public:
//
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
static_assert(cute::is_any_of_v<TileScheduler,
detail::PersistentTileSchedulerSm90Group<ProblemShape>,
detail::PersistentTileSchedulerSm90>);
if(TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
collective_mainloop.mma(
mainloop_pipeline,
@ -764,18 +797,16 @@ public:
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(work_k_tile_count);
}
// Index of warp group within consumer warp groups
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
// Perform reduction across splits, if needed
TileScheduler::fixup(
params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx);
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
if (did_batch_change) {
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_store_tensormap);
}
if (did_batch_change) {
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_store_tensormap);
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
@ -804,20 +835,31 @@ public:
did_batch_change = curr_batch != work_tile_info.L_idx;
if (work_tile_info.is_valid() && did_batch_change) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_store_tensormap,
work_tile_info.L_idx
);
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
if (warp_idx_in_warp_group == 0) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_store_tensormap,
problem_shape_MNKL,
work_tile_info.L_idx,
consumer_warp_group_idx
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}
} // Scheduler work fetch loop
// Cooperative only needs TMA to complete at the very end of the kernel
if (do_store_tail) {
collective_epilogue.store_tail(
epi_load_pipeline,
@ -829,7 +871,6 @@ public:
} // Consumer Warp Groups End
#endif
}
};
///////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,949 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/workspace.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal_decl.h"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <
class ProblemShape_,
class CollectiveMainloop_,
class CollectiveEpilogue_,
class TileScheduler_
>
class GemmUniversal<
ProblemShape_,
CollectiveMainloop_,
CollectiveEpilogue_,
TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>>
>
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
static_assert(cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>);
static constexpr bool IsGdcEnabled = false;
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using InternalStrideA = typename CollectiveMainloop::InternalStrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using InternalStrideB = typename CollectiveMainloop::InternalStrideB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using Schedule = typename DispatchPolicy::Schedule;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using InternalStrideC = typename CollectiveEpilogue::InternalStrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using InternalStrideD = typename CollectiveEpilogue::InternalStrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(ArchTag::kMinComputeCapability >= 90);
static_assert(cute::is_void_v<TileScheduler_>,
"Ptr-Array Pingpong and Grouped Gemm Pingpong kernel only supports the default scheduler.");
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;
using TileScheduler = cute::conditional_t<IsGroupedGemmKernel,
typename detail::TileSchedulerSelector<
GroupScheduler, ArchTag,
TileShape, ClusterShape,
ProblemShape>::Scheduler,
typename detail::TileSchedulerSelector<
void, ArchTag, TileShape, ClusterShape>::Scheduler>;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
using MathWarpGroupOrderBarrierSharedStorage = cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage<
MathWarpGroupOrderBarrier::SequenceDepth,
MathWarpGroupOrderBarrier::SequenceLength>;
// Kernel level shared memory storage
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16> {
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
using MathWarpGroupOrderBarrierStorage = MathWarpGroupOrderBarrierSharedStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
} pipelines;
struct TensorMapStorage : cute::aligned_struct<128> {
using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage;
using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage;
alignas(128) MainloopTensorMapStorage mainloop;
alignas(128) EpilogueTensorMapStorage epilogue;
} tensormaps;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
void* workspace{nullptr};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static
Params
to_underlying_arguments(Arguments const& args, void* workspace) {
CUTLASS_TRACE_HOST("to_underlying_arguments():");
ProblemShape problem_shapes = args.problem_shape;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* mainloop_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
// Precompute the number of epilogue subtiles and pass it to the tile scheduler; it is used by the
// separate-reduction scheme in the stream-K case. NumEpilogueSubTiles defaults to 1, which means
// subtiles are not used and separate reduction is not enabled.
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
TileSchedulerParams scheduler;
if constexpr (IsGroupedGemmKernel) {
scheduler = TileScheduler::to_underlying_arguments(
problem_shapes, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
}
else {
scheduler = TileScheduler::to_underlying_arguments(
problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
}
return {
args.mode,
problem_shapes,
CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace),
hw_info,
scheduler,
workspace
};
}
static bool
can_implement(Arguments const& args) {
bool implementable = true;
if constexpr (IsGroupedGemmKernel) {
// Group GEMM currently only supports rank-3 problem shapes
implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3);
} else {
implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n");
return implementable;
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_size = 0;
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
workspace_size += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
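// Sketch of the workspace layout this sizes (region contents are descriptive assumptions):
//   [ tile-scheduler workspace                    ] -> padded to MinWorkspaceAlignment
//   [ epilogue workspace, sized with sm_count     ] -> padded to MinWorkspaceAlignment
//   [ mainloop workspace, sized with sm_count     ] -> padded to MinWorkspaceAlignment
// to_underlying_arguments() and initialize_workspace() carve the same three regions in the
// same order with the same rounding.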
static cutlass::Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
CudaHostAdapter* cuda_adapter = nullptr) {
Status status = Status::kSuccess;
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
status = TileScheduler::template initialize_workspace<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, cuda_adapter);
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3
get_grid_shape(Params const& params) {
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM;
dim3 grid_shape;
if constexpr (IsGroupedGemmKernel) {
grid_shape = TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
}
else {
grid_shape = TileScheduler::get_grid_shape(params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args);
}
return grid_shape;
}
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void
operator()(Params const& params, char* smem_buf) {
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(size(TiledMma{}) == 128, "Pingpong kernel must have TiledMMA operating using 128 threads.");
static_assert(NumMmaWarpGroups == 2, "Pingpong kernels currently only support NumMmaWarpGroups == 2");
if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v<typename CollectiveEpilogue::DispatchPolicy>) {
static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups,
"Tiled MmA does not match expected warp groups performing the epilogue");
}
static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
enum class WarpGroupRole {
Producer = 0,
Consumer0 = 1,
Consumer1 = 2
};
enum class ProducerWarpRole {
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
int mma_thread_idx = thread_idx % size(TiledMma{});
auto warp_group_idx = canonical_warp_group_idx();
auto warp_group_role = WarpGroupRole(warp_group_idx);
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Note: Tma Descriptor Prefetch (from either const or param) is not applicable here
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
}
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_math_wg_order_barrier.group_id = warp_group_idx - static_cast<int>(WarpGroupRole::Consumer0);
params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group
MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = [] () {
// We need this to guarantee that the pipeline init is visible to all producers and
// consumer thread blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
return [] () { cute::cluster_wait(); };
}
else {
__syncthreads();
return [] () {}; // do nothing
}
} ();
// Get the appropriate blocks for this thread block -- potential for thread block locality
TiledMma tiled_mma;
const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
TileScheduler scheduler{params.scheduler};
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
if (not work_tile_info.is_valid()) {
// When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups
return;
}
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
if (warp_group_role == WarpGroupRole::Consumer1) {
// Advance 2nd Math WG to the next work tile for the startup
const auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
work_tile_info = next_work_tile_info;
if (!work_tile_info.is_valid()) {
return;
}
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
mainloop_pipe_consumer_state.advance(k_tile_count);
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
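// Worked startup offsets for Consumer1 above (illustrative numbers, assuming the first tile
// spans k_tile_count = 8 k-tiles and the epilogue does c_tile_count = 2 loads and
// d_tile_count = 2 stores per output tile):
//   mainloop_pipe_consumer_state  advances by 8  (skips Consumer0's mainloop stages)
//   epi_load_pipe_consumer_state  advances by 2  (skips Consumer0's epilogue loads)
//   epi_store_pipe_producer_state advances by 2  (skips Consumer0's epilogue stores)
// From then on the two math warp groups alternate output tiles, which is the pingpong pattern.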
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
if (warp_group_role == WarpGroupRole::Producer) {
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop) {
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
int32_t const mock_l_coord = 0;
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
int32_t const sm_count = params.hw_info.sm_count;
// Fetch a copy of tensormaps for the CTA
auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx);
// Update tensormap for the initial batch for the CTA
if (work_tile_info.is_valid()) {
collective_mainloop.tensormaps_perform_update(
shared_storage.tensormaps.mainloop,
params.mainloop,
input_tensormaps,
problem_shape_MNKL,
curr_batch
);
// Ensure warp is converged before issuing tensormap fence release
__syncwarp();
// Entire warp must do this (i.e. it's aligned)
collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);
}
bool do_load_order_arrive = true;
bool did_batch_change = true;
while (work_tile_info.is_valid()) {
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
work_tile_info = next_work_tile_info;
continue;
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord);
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the work.
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
if (did_batch_change) {
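// The tensormap in global memory was rewritten for this batch; acquire it so TMA sees the update.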
collective_mainloop.tensormaps_fence_acquire(input_tensormaps);
}
collective_mainloop.load(
params.mainloop,
mainloop_pipeline,
mainloop_pipe_producer_state,
load_inputs,
input_tensormaps,
blk_coord,
k_tile_iter, work_k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
// Update starting pipeline state for the next tile
// Wait for the last TMA stage to complete loading, before issuing tensormap updates
mainloop_pipe_producer_state.advance(work_k_tile_count - 1);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
work_tile_info = next_work_tile_info;
auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx
did_batch_change = next_batch != curr_batch;
if (work_tile_info.is_valid() && did_batch_change) {
curr_batch = next_batch;
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
}
// The purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
// Since this state is waiting for loads to finish, it must start in the inverted phase.
typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state =
{mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()};
mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state);
collective_mainloop.tensormaps_perform_update(
shared_storage.tensormaps.mainloop,
params.mainloop,
input_tensormaps,
problem_shape_MNKL,
curr_batch
);
// Ensure warp is converged before issuing tensor replace
__syncwarp();
// Entire warp must do this (i.e. it's aligned)
collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);
}
// Advance the producer state for the last remaining stage that was being waited for above
mainloop_pipe_producer_state.advance(1);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
int32_t const sm_count = params.hw_info.sm_count;
auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx));
bool did_batch_change = true;
constexpr bool IsEpiLoad = true;
if (work_tile_info.is_valid()) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_load_tensormap,
problem_shape_MNKL,
work_tile_info.L_idx,
0
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
}
load_order_barrier.wait();
while (work_tile_info.is_valid()) {
int32_t curr_batch = work_tile_info.L_idx;
// Get next work tile
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
if (did_batch_change) {
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
}
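// If the next tile starts a new batch, have the epilogue load wait on its outstanding TMA loads
// so the tensormap can be updated safely afterwards.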
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline,
epi_load_pipe_producer_state,
problem_shape_MNKL,
blk_shape,
blk_coord,
tiled_mma,
lane_idx,
shared_storage.tensors.epilogue,
epi_load_tensormap,
work_tile_info.reduction_subtile_idx(),
wait
);
}
work_tile_info = next_work_tile_info;
did_batch_change = curr_batch != work_tile_info.L_idx;
if (work_tile_info.is_valid() && did_batch_change) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
// tensormap update
{
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_load_tensormap,
problem_shape_MNKL,
work_tile_info.L_idx,
0
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
}
}
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
// Index of warp group within consumer warp groups
int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
int32_t const sm_count = params.hw_info.sm_count;
// Whether we may need to issue tail arrives for TMA stores, in case the epilogue load is waiting on them
bool do_store_tail = false;
// Get a copy of tensormaps
auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx));
bool did_batch_change = true;
constexpr bool IsEpiLoad = false;
if (work_tile_info.is_valid()) {
if (warp_idx_in_warp_group == 0) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_store_tensormap,
problem_shape_MNKL,
work_tile_info.L_idx,
consumer_warp_group_idx
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}
while (work_tile_info.is_valid()) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
int32_t curr_batch = work_tile_info.L_idx;
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
// Allocate the accumulators for the (M,N) blk_shape
//
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
static_assert(cute::is_any_of_v<TileScheduler,
detail::PersistentTileSchedulerSm90Group<ProblemShape>,
detail::PersistentTileSchedulerSm90>);
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
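// Order the two Math WGs' MMA one after the other, which helps hide the epilogue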
math_wg_order_barrier.wait();
collective_mainloop.mma(
mainloop_pipeline,
mainloop_pipe_consumer_state,
accumulators,
work_k_tile_count,
mma_thread_idx,
shared_storage.tensors.mainloop,
params.mainloop
);
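// Cue for the next Math WG's MMA to start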
math_wg_order_barrier.arrive();
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop.mma_tail(
mainloop_pipeline,
mainloop_pipe_consumer_state,
work_k_tile_count
);
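// Order the two Math WGs' epilogues one after the other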
math_wg_order_barrier.wait();
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(work_k_tile_count);
}
// Perform reduction across splits, if needed
TileScheduler::fixup(
params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx);
if (did_batch_change) {
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_store_tensormap);
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
collective_epilogue.store(
epi_load_pipeline,
epi_load_pipe_consumer_state,
epi_store_pipeline,
epi_store_pipe_producer_state,
problem_shape_MNKL,
blk_shape,
blk_coord,
accumulators,
tiled_mma,
mma_thread_idx,
shared_storage.tensors.epilogue,
epi_store_tensormap,
work_tile_info.reduction_subtile_idx()
);
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
do_store_tail = true;
}
// Get next work tile
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
work_tile_info = next_work_tile_info;
// Skip a tile for pingpong
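// The other math warp group computes the tile just fetched, so only advance the mainloop
// consumer state past its k-tiles and then fetch the tile after it.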
if (work_tile_info.is_valid()) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
mainloop_pipe_consumer_state.advance(work_k_tile_count);
// Go to next tile
auto next_next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
work_tile_info = next_next_work_tile_info;
}
did_batch_change = curr_batch != work_tile_info.L_idx;
if (work_tile_info.is_valid() && did_batch_change) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
}
if (warp_idx_in_warp_group == 0) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
shared_storage.tensormaps.epilogue,
params.epilogue,
epi_store_tensormap,
problem_shape_MNKL,
work_tile_info.L_idx,
consumer_warp_group_idx
);
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}
// The TMA store pipeline wait is only visible to the TMA-issuing warp, so for multiple-consumer kernels
// we need to wait for all TMA stores to complete before arriving on the consumer order barrier,
// to ensure the next math consumer doesn't overwrite smem referenced by this consumer's in-flight TMA stores.
auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] =
collective_epilogue.store_tail(
epi_load_pipeline,
epi_load_pipe_consumer_state,
epi_store_pipeline,
epi_store_pipe_producer_state
);
// Update starting load/store pipeline states for the next tile
// The state has already been advanced by one tile inside the collective calls; advance once more
// to skip the tile handled by the other math WG (pingpong)
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier.arrive();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
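For reference, here is a minimal host-side sketch (not part of this commit) of how a grouped GEMM built on the kernel above is typically launched, following the Arguments layout of the CUTLASS 3.x grouped-GEMM examples. Gemm stands for a cutlass::gemm::device::GemmUniversalAdapter<GemmKernel> such as the ones instantiated in the tests further below; the function name run_grouped_gemm, its parameter list, and the per-group pointer/stride/problem-size arrays are illustrative placeholders assumed to be prepared on the device by the caller.
#include "cutlass/util/device_memory.h"  // workspace allocation helper; remaining types come from the headers already included alongside the Gemm definition
// Hedged sketch: launch a grouped GEMM (one (M,N,K) per group) through the device adapter.
template <class Gemm, class ProblemSizesDev, class ProblemSizesHost,
          class PtrsA, class StridesA, class PtrsB, class StridesB,
          class PtrsC, class StridesC, class PtrsD, class StridesD>
cutlass::Status run_grouped_gemm(
    int num_groups,
    ProblemSizesDev problem_sizes_device,   // device array of per-group (M,N,K)
    ProblemSizesHost problem_sizes_host,    // optional host copy (may be nullptr)
    PtrsA ptr_A, StridesA stride_A,         // device arrays with one entry per group
    PtrsB ptr_B, StridesB stride_B,
    PtrsC ptr_C, StridesC stride_C,
    PtrsD ptr_D, StridesD stride_D) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = 0;
  hw_info.sm_count  = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
  typename Gemm::Arguments args{
    cutlass::gemm::GemmUniversalMode::kGrouped,
    {num_groups, problem_sizes_device, problem_sizes_host},
    {ptr_A, stride_A, ptr_B, stride_B},
    {{}, ptr_C, stride_C, ptr_D, stride_D},  // {} keeps the default linear-combination params (alpha=1, beta=0)
    hw_info
  };
  Gemm gemm;
  cutlass::device_memory::allocation<uint8_t> workspace(Gemm::get_workspace_size(args));
  cutlass::Status status = gemm.can_implement(args);
  if (status != cutlass::Status::kSuccess) { return status; }
  status = gemm.initialize(args, workspace.get());
  if (status != cutlass::Status::kSuccess) { return status; }
  return gemm.run();
}
A call would look like run_grouped_gemm<Gemm>(groups, sizes_dev, sizes_host, dA_ptrs, dA_strides, dB_ptrs, dB_strides, dC_ptrs, dC_strides, dD_ptrs, dD_strides); the same sequence applies to both the cooperative and the pingpong schedules, since only the schedule and tile-shape types used to build GemmKernel change.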

View File

@@ -338,12 +338,14 @@ cutlass_test_unit_add_executable(
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array
sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu
)
# Group Gemm test
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm
sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu
sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu
)
# Fused epilogue tests

View File

@@ -1005,7 +1005,7 @@ struct HostCollectiveEpilogue {
stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t<LayoutTagAux>{}, cute::make_shape(M, N, 1));
}
static_assert(!IsGroupGemm or (IsGroupGemm and IsAuxOutEnabled));
static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled));
if constexpr (IsAuxOutEnabled) {
for (int32_t i = 0; i < L; ++i) {
@@ -1323,8 +1323,16 @@ struct HostCollectiveEpilogue {
cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch]));
auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()),
cute::make_layout(cute::make_shape(M, cute::_1{})));
auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensors_Aux[batch].host_data() : references_Aux[batch].host_data()),
cute::make_layout(cute::make_shape(M, N, 1), stride_Aux));
auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux);
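// Pick the Aux data source: the fused-input tensor when Aux is an input, the reference buffer
// when it is an output, and a null iterator when neither is enabled.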
auto Aux = [&]() {
auto ptr = recast_ptr<ElementAux>(nullptr);
if (IsAuxInEnabled) {
ptr = detail::make_iterator(tensors_Aux[batch].host_data());
} else if (IsAuxOutEnabled) {
ptr = detail::make_iterator(references_Aux[batch].host_data());
}
return cute::make_tensor(ptr, Aux_layout);
}();
auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()),
cute::make_layout(cute::make_shape(M, cute::_1{})));
auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()),

View File

@@ -78,7 +78,69 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool result = TestAll<Gemm>(1.0, 1.0);
EXPECT_TRUE(result);
result = TestAll<Gemm>(1.0, 0.0);
EXPECT_TRUE(result);
}
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm, 128x128x64_2x2x1_direct_store) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
@@ -115,6 +177,8 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool result = TestAll<Gemm>(1.0, 1.0);
EXPECT_TRUE(result);
result = TestAll<Gemm>(1.0, 0.0);
EXPECT_TRUE(result);
}
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

View File

@@ -0,0 +1,184 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide Ptr-Array Ping-pong scheduler GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x_ptr_array.hpp"
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
using namespace cute;
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool result = TestAll<Gemm>(1.0, 1.0);
EXPECT_TRUE(result);
result = TestAll<Gemm>(1.0, 0.0);
EXPECT_TRUE(result);
}
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1_direct_store) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool result = TestAll<Gemm>(1.0, 1.0);
EXPECT_TRUE(result);
result = TestAll<Gemm>(1.0, 0.0);
EXPECT_TRUE(result);
}
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
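The two problem-shape wrappers exercised by these tests differ only in how the batch dimension is described; a short contrast follows, with the template arguments taken from the test definitions in this commit (the alias names are illustrative):
// Group GEMM: each group may have its own (M,N,K); shapes live in a device array, which is why
// the kernel re-reads params.problem_shape.get_problem_shape(L_idx) whenever the batch changes.
using GroupShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>;
// Ptr-array GEMM: a single (M,N,K) shared by all L batches; only the A/B/C/D pointers vary per batch.
using ArrayShape = cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>;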

View File

@@ -115,9 +115,11 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool result = TestAll<Gemm>(1.0, 1.0);
EXPECT_TRUE(result);
result = TestAll<Gemm>(1.0, 0.0);
EXPECT_TRUE(result);
}
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_NoSmemEpi) {
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_direct_store) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
@@ -173,6 +175,7 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(TestAll<Gemm>(1.0, 1.0));
EXPECT_TRUE(TestAll<Gemm>(1.0, 0.0));
}

View File

@@ -0,0 +1,182 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide Ptr-Array GEMM interface
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../../common/cutlass_unit_test.h"
#include "gemm_testbed_3x_ptr_array.hpp"
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
using namespace cute;
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
bool result = TestAll<Gemm>(1.0, 1.0);
EXPECT_TRUE(result);
result = TestAll<Gemm>(1.0, 0.0);
EXPECT_TRUE(result);
}
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1_direct_store) {
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue
>;
using namespace test::gemm::device;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
EXPECT_TRUE(TestAll<Gemm>(1.0, 1.0));
EXPECT_TRUE(TestAll<Gemm>(1.0, 0.0));
}
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)