Support for TMA Epilogue for Group Gemm and add pingpong ptr array & Group Gemm (#1795)
@@ -95,40 +95,66 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
  cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  TileShape, ClusterShape,
  cutlass::epilogue::collective::EpilogueTileAuto,
  ElementAccumulator, ElementAccumulator,
  ElementC, LayoutC, AlignmentC,
  ElementC, LayoutC, AlignmentC,
  EpilogueSchedule
>::CollectiveOp;
// Different configs for pingpong/cooperative
struct CooperativeConfig {
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
  using TileShape = Shape<_256,_128,_64>;
  using ClusterShape = Shape<_1,_2,_1>;
};

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
  ArchTag, OperatorClass,
  ElementA, LayoutA, AlignmentA,
  ElementB, LayoutB, AlignmentB,
  ElementAccumulator,
  TileShape, ClusterShape,
  cutlass::gemm::collective::StageCountAutoCarveout<
    static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
  KernelSchedule
>::CollectiveOp;
struct PingpongConfig {
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = Shape<_64,_128,_64>;
  using ClusterShape = Shape<_1,_1,_1>;
};

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
  CollectiveMainloop,
  CollectiveEpilogue
>;
template <typename ScheduleConfig>
struct GemmGivenSchedule {
  using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
  using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
  using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
  using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, AlignmentC,
    ElementC, LayoutC, AlignmentC,
    EpilogueSchedule
  >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
    CollectiveMainloop,
    CollectiveEpilogue
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;

using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;
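For orientation, either specialization is driven through the adapter API the same way; a minimal launch sketch only, assuming the `options`, `args_from_options<>()`, and `CUTLASS_CHECK` helpers defined further down in this example:

// Minimal usage sketch (cooperative shown; GemmPingpong works identically).
Gemm gemm;
auto arguments = args_from_options<Gemm>(options);            // kArray-mode arguments (see below)
size_t workspace_size = Gemm::get_workspace_size(arguments);  // extra gmem, e.g. TMA descriptors
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm.can_implement(arguments));                 // alignment/shape checks
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));   // set up params and workspace
CUTLASS_CHECK(gemm.run());                                    // launch the kernel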

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
@ -261,14 +287,14 @@ bool initialize_block(
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(-2);
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
scope_max = static_cast<Element>(8);
|
||||
scope_min = static_cast<Element>(-8);
|
||||
}
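For context, the casted bounds above feed straight into the random block fill; a self-contained reconstruction of the helper after this change, assuming the `cutlass::DeviceAllocation`-based signature used elsewhere in this example:

template <typename Element>
bool initialize_block(cutlass::DeviceAllocation<Element>& block, uint64_t seed = 2023) {
  Element scope_max, scope_min;
  int bits_input = cutlass::sizeof_bits<Element>::value;
  if (bits_input == 1) {
    scope_max = static_cast<Element>(2);
    scope_min = static_cast<Element>(0);
  } else if (bits_input <= 8) {
    scope_max = static_cast<Element>(2);
    scope_min = static_cast<Element>(-2);
  } else {
    scope_max = static_cast<Element>(8);
    scope_min = static_cast<Element>(-8);
  }
  // Fill the allocation with uniform random values in [scope_min, scope_max]
  cutlass::reference::device::BlockFillRandomUniform(
      block.get(), block.size(), seed, scope_max, scope_min, 0);
  return true;
}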
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
@ -351,7 +377,8 @@ void initialize(const Options &options) {
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
template <typename GemmT>
|
||||
typename GemmT::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
@ -359,7 +386,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
typename GemmT::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kArray,
|
||||
{{options.m, options.n, options.k, options.l}},
|
||||
{ptr_A.get(), stride_A, ptr_B.get(), stride_B},
|
||||
@ -405,20 +432,20 @@ bool verify(const Options &options) {
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
template <typename GemmT>
|
||||
int run(Options &options)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
GemmT gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
auto arguments = args_from_options<GemmT>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = GemmT::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
@ -510,7 +537,10 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
std::cout << "\n*** Cooperative schedule ***" << std::endl;
|
||||
run<Gemm>(options);
|
||||
std::cout << "\n*** Pingpong schedule ***" << std::endl;
|
||||
run<GemmPingpong>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // A
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
// Different configs for pingpong/cooperative
|
||||
struct CooperativeConfig {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_256,_128,_128>;
|
||||
using ClusterShape = Shape<_2,_2,_1>;
|
||||
};
|
||||
|
||||
struct PingpongConfig {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = Shape<_128,_128,_128>;
|
||||
using ClusterShape = Shape<_2,_1,_1>;
|
||||
};
|
||||
|
||||
template <typename ScheduleConfig>
|
||||
struct GemmGivenSchedule {
|
||||
using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
EpilogueSchedule
|
||||
EpilogueSchedule,
|
||||
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
using GemmKernel = GemmGivenSchedule<CooperativeConfig>::GemmKernel;
|
||||
using Gemm = GemmGivenSchedule<CooperativeConfig>::Gemm;
|
||||
|
||||
using GemmKernelPingpong = GemmGivenSchedule<PingpongConfig>::GemmKernel;
|
||||
using GemmPingpong = GemmGivenSchedule<PingpongConfig>::Gemm;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
@ -271,10 +297,10 @@ struct Options {
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = ((rand() % 512) + 1);
|
||||
m = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = ((rand() % 512) + 1);
|
||||
n = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
@ -521,7 +547,8 @@ void initialize(const Options &options) {
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
|
||||
template <typename GemmT>
|
||||
typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
@ -529,33 +556,49 @@ typename Gemm::Arguments args_from_options(const Options &options, bool host_pro
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::EpilogueOutputOp::Params params;
|
||||
typename GemmT::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
params = typename Gemm::EpilogueOutputOp::Params(
|
||||
ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
if (host_problem_shapes_available) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
arguments = typename GemmT::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments = typename Gemm::Arguments {
|
||||
arguments = typename GemmT::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
@ -605,20 +648,20 @@ bool verify(const Options &options) {
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
template <typename GemmT>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
GemmT gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options, host_problem_shapes_available);
|
||||
auto arguments = args_from_options<GemmT>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = GemmT::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
@ -713,8 +756,14 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
std::cout << "\n*** Cooperative schedule ***" << std::endl;
|
||||
run<Gemm>(options);
|
||||
std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl;
|
||||
run<Gemm>(options, false /*host_problem_shapes_available*/);
|
||||
std::cout << "\n*** Pingpong schedule ***" << std::endl;
|
||||
run<GemmPingpong>(options);
|
||||
std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl;
|
||||
run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -32,10 +32,10 @@
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes
|
||||
|
||||
@ -274,4 +274,18 @@ struct conditional_template<false, True, False> {
|
||||
using type = False<U...>;
|
||||
};
|
||||
|
||||
//
|
||||
// is_any_of
|
||||
//
|
||||
|
||||
/// Member `value` is true if and only if T is the same as (is_same_v) at least one of the types in Us
|
||||
template <typename T, typename... Us>
|
||||
struct is_any_of {
|
||||
constexpr static bool value = (... || CUTE_STL_NAMESPACE::is_same_v<T, Us>);
|
||||
};
|
||||
|
||||
/// Is true if and only if T is the same as (is_same_v) at least one of the types in Us
|
||||
template <typename T, typename... Us>
|
||||
inline constexpr bool is_any_of_v = is_any_of<T, Us...>::value;
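A quick usage sketch of the trait defined above:

// Compile-time checks: float appears in the list; char does not.
static_assert(cute::is_any_of_v<float, int, float, double>, "");
static_assert(!cute::is_any_of_v<char, int, float, double>, "");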
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -71,14 +71,18 @@ sm90_get_tma_dispatch_policy() {
|
||||
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
|
||||
constexpr bool ReuseSmem = (sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>) && (sizeof_bits_v<ElementD> > 8);
|
||||
// TMA store delay performs worse with residual loads and complicates tensormap updates for Ptr-Array GEMMs
|
||||
constexpr bool DelayTmaStore = is_void_v<ElementC> && !detail::sm90_is_tma_ptr_array_v<Schedule>;
|
||||
constexpr bool DelayTmaStore = is_void_v<ElementC> && !detail::sm90_is_ptr_array_tma_v<Schedule>;
|
||||
constexpr int StagesD = cute::min(EpiTiles, 2);
|
||||
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
|
||||
: cute::min(EpiTiles, 4);
|
||||
|
||||
return cute::conditional_t<detail::sm90_is_tma_ptr_array_v<Schedule>,
|
||||
Sm90PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>,
|
||||
Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>>{};
|
||||
if constexpr (detail::sm90_is_ptr_array_tma_v<Schedule>) {
|
||||
return Sm90PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem,
|
||||
DelayTmaStore, Schedule::NumEpilogueWarpGroups>{};
|
||||
}
|
||||
else {
|
||||
return Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the smem layout atom to be used for C or D matrix
|
||||
@ -255,6 +259,9 @@ struct Sm90TmaBuilderImpl {
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
using UnderlyingGmemStrideTypeC = cute::remove_pointer_t<GmemStrideTypeC>;
|
||||
using UnderlyingGmemStrideTypeD = cute::remove_pointer_t<GmemStrideTypeD>;
|
||||
|
||||
using CopyOpS2G = cute::conditional_t<detail::is_im2col_mode<GmemLayoutTagD>,
|
||||
SM90_TMA_STORE_IM2COL,
|
||||
SM90_TMA_STORE
|
||||
@ -267,17 +274,11 @@ struct Sm90TmaBuilderImpl {
|
||||
// Get the smallest tiled copy we can use to retile the accumulators
|
||||
using CopyAtomC = Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>;
|
||||
|
||||
using FusionDispatchPolicy = Sm90TmaWarpSpecialized<DispatchPolicy::StagesC,
|
||||
DispatchPolicy::StagesD,
|
||||
DispatchPolicy::FragmentSize,
|
||||
DispatchPolicy::ReuseSmemC,
|
||||
DispatchPolicy::DelayTmaStore>;
|
||||
|
||||
// TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
|
||||
// instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
|
||||
using FusionCallbacks =
|
||||
typename CallbacksBuilder<
|
||||
FusionDispatchPolicy,
|
||||
DispatchPolicy,
|
||||
FusionOpOrCallbacks,
|
||||
TileShape_MNK,
|
||||
EpilogueTile_MN,
|
||||
@ -294,11 +295,11 @@ struct Sm90TmaBuilderImpl {
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
CopyOpG2S,
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, EpilogueTile_MN>()),
|
||||
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<UnderlyingGmemStrideTypeC, ElementC, EpilogueTile_MN>()),
|
||||
decltype(detail::sm90_get_smem_load_op_for_source<UnderlyingGmemStrideTypeC, ElementC>()),
|
||||
CopyOpS2G,
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
|
||||
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>()),
|
||||
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<UnderlyingGmemStrideTypeD, ElementD, EpilogueTile_MN>()),
|
||||
decltype(detail::sm90_get_smem_store_op_for_accumulator<UnderlyingGmemStrideTypeD, ElementD>()),
|
||||
CopyAtomC
|
||||
>;
|
||||
};
|
||||
@ -483,7 +484,7 @@ struct CollectiveBuilder<
|
||||
FusionOperation,
|
||||
cute::enable_if_t<cute::is_same_v<Schedule, TmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, TmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<Schedule, PtrArrayTmaWarpSpecializedCooperative> >> {
|
||||
detail::sm90_is_ptr_array_tma_v<Schedule>>> {
|
||||
private:
|
||||
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
|
||||
fusion::get_element_aux_t<FusionOperation>, ElementD_>;
|
||||
|
||||
@ -71,6 +71,62 @@ is_im2col() {
|
||||
|| cute::is_same_v<Stride, cutlass::detail::TagToStrideC_t<cutlass::layout::TensorNDHWC>>;
|
||||
}
|
||||
|
||||
template<class Schedule>
|
||||
struct sm90_is_ptr_array_tma : cute::false_type {};
|
||||
|
||||
template<>
|
||||
struct sm90_is_ptr_array_tma<PtrArrayTmaWarpSpecializedCooperative> : cute::true_type {};
|
||||
|
||||
template<>
|
||||
struct sm90_is_ptr_array_tma<PtrArrayTmaWarpSpecializedPingpong> : cute::true_type {};
|
||||
|
||||
template<>
|
||||
struct sm90_is_ptr_array_tma<PtrArrayTmaWarpSpecialized> : cute::true_type {};
|
||||
|
||||
template<class Schedule>
|
||||
static constexpr bool sm90_is_ptr_array_tma_v = sm90_is_ptr_array_tma<Schedule>::value;
|
||||
|
||||
template<class Schedule>
|
||||
struct sm90_is_ptr_array_tma_cooperative : cute::false_type {};
|
||||
|
||||
template<>
|
||||
struct sm90_is_ptr_array_tma_cooperative<PtrArrayTmaWarpSpecializedCooperative> : cute::true_type {};
|
||||
|
||||
template<class Schedule>
|
||||
static constexpr bool sm90_is_ptr_array_tma_cooperative_v = sm90_is_ptr_array_tma_cooperative<Schedule>::value;
|
||||
|
||||
template<class Schedule>
|
||||
struct sm90_is_ptr_array_tma_pingpong : cute::false_type {};
|
||||
|
||||
template<>
|
||||
struct sm90_is_ptr_array_tma_pingpong<PtrArrayTmaWarpSpecializedPingpong> : cute::true_type {};
|
||||
|
||||
template<class Schedule>
|
||||
static constexpr bool sm90_is_ptr_array_tma_pingpong_v = sm90_is_ptr_array_tma_pingpong<Schedule>::value;
|
||||
|
||||
template<class DispatchPolicy>
|
||||
struct sm90_is_ptr_array_tma_dispatch_policy : cute::false_type {};
|
||||
|
||||
template<
|
||||
int StagesC,
|
||||
int StagesD,
|
||||
int FragmentSize,
|
||||
bool ReuseSmemC,
|
||||
bool DelayTmaStore,
|
||||
int NumEpilogueWarpGroups
|
||||
>
|
||||
struct sm90_is_ptr_array_tma_dispatch_policy<
|
||||
Sm90PtrArrayTmaWarpSpecialized<StagesC,
|
||||
StagesD,
|
||||
FragmentSize,
|
||||
ReuseSmemC,
|
||||
DelayTmaStore,
|
||||
NumEpilogueWarpGroups>>
|
||||
: cute::true_type {};
|
||||
|
||||
template<class DispatchPolicy>
|
||||
static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy<DispatchPolicy>::value;
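Taken together, the traits above classify the new ptr-array epilogue schedules; a few illustrative compile-time checks (unqualified names assume this header's namespace, where the schedule tags are visible):

static_assert(sm90_is_ptr_array_tma_v<PtrArrayTmaWarpSpecializedCooperative>, "");
static_assert(sm90_is_ptr_array_tma_v<PtrArrayTmaWarpSpecializedPingpong>, "");
static_assert(sm90_is_ptr_array_tma_cooperative_v<PtrArrayTmaWarpSpecializedCooperative>, "");
static_assert(!sm90_is_ptr_array_tma_cooperative_v<PtrArrayTmaWarpSpecializedPingpong>, "");
static_assert(sm90_is_ptr_array_tma_pingpong_v<PtrArrayTmaWarpSpecializedPingpong>, "");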
|
||||
|
||||
using cutlass::atomic_maximum;
|
||||
|
||||
template <class T>
|
||||
@ -79,14 +135,11 @@ static constexpr int elements_per_access_v = cutlass::sizeof_bits<uint32_t>::val
|
||||
template <class EpilogueSchedule>
|
||||
static constexpr bool sm90_is_cooperative_v =
|
||||
cute::is_base_of_v<cutlass::epilogue::TmaWarpSpecializedCooperative, EpilogueSchedule> ||
|
||||
cute::is_base_of_v<cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative, EpilogueSchedule>;
|
||||
|
||||
template <class EpilogueSchedule>
|
||||
static constexpr bool sm90_is_tma_ptr_array_v =
|
||||
cute::is_base_of_v<cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative, EpilogueSchedule>;
|
||||
sm90_is_ptr_array_tma_cooperative_v<EpilogueSchedule>;
|
||||
|
||||
template <class EpilogueSchedule>
|
||||
static constexpr bool sm90_is_warp_specialized_v =
|
||||
(!sm90_is_ptr_array_tma_cooperative_v<EpilogueSchedule> && sm90_is_ptr_array_tma_v<EpilogueSchedule>) ||
|
||||
cute::is_base_of_v<cutlass::epilogue::TmaWarpSpecialized, EpilogueSchedule>;
|
||||
|
||||
template <class GmemLayoutTag>
|
||||
@ -199,7 +252,11 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto
|
||||
load_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, [[maybe_unused]] int32_t const sm_idx) const {
|
||||
load_init(
|
||||
[[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] int32_t sm_count,
|
||||
[[maybe_unused]] int32_t sm_idx) {
|
||||
return cute::make_tuple(nullptr);
|
||||
}
|
||||
|
||||
@ -243,7 +300,7 @@ public:
|
||||
[[maybe_unused]] TensorStorage& shared_tensors,
|
||||
[[maybe_unused]] TensorMapC const& load_tensormap,
|
||||
[[maybe_unused]] int subtile_idx=-1,
|
||||
[[maybe_unused]] bool return_prior_state = false)
|
||||
[[maybe_unused]] bool wait = false)
|
||||
{
|
||||
return load_pipe_producer_state;
|
||||
}
|
||||
@ -257,8 +314,12 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto
|
||||
store_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count,
|
||||
[[maybe_unused]] int32_t const sm_idx) const {
|
||||
store_init(
|
||||
[[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] int32_t sm_count,
|
||||
[[maybe_unused]] int32_t sm_idx,
|
||||
[[maybe_unused]] int32_t warp_group_idx) {
|
||||
return cute::make_tuple(nullptr);
|
||||
}
|
||||
|
||||
@ -369,22 +430,25 @@ public:
|
||||
|
||||
// Dummy methods to perform different parts of TMA/Tensormap modifications
|
||||
|
||||
template <bool IsLoad>
|
||||
template <bool IsLoad, class ProblemShapeMNKL>
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_perform_update(
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormap,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap,
|
||||
[[maybe_unused]] int32_t next_batch) { }
|
||||
[[maybe_unused]] ProblemShapeMNKL problem_shape,
|
||||
[[maybe_unused]] int32_t next_batch,
|
||||
[[maybe_unused]] int32_t warp_group_idx) { }
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_cp_fence_release(
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormap,
|
||||
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap,
|
||||
[[maybe_unused]] uint32_t lane_predicate) { }
|
||||
[[maybe_unused]] uint32_t lane_predicate,
|
||||
[[maybe_unused]] int32_t warp_group_idx) { }
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@ -44,9 +44,10 @@
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
#include "cutlass/detail/layout.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
#include "cute/atom/copy_traits_sm90_tma.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -62,6 +63,7 @@ template <
|
||||
int FragmentSize_,
|
||||
bool ReuseSmemC_,
|
||||
bool DelayTmaStore_,
|
||||
int NumEpilogueWarpGroups_,
|
||||
class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K)
|
||||
class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N)
|
||||
class ElementC_,
|
||||
@ -78,7 +80,13 @@ template <
|
||||
class CopyAtomC_
|
||||
>
|
||||
class CollectiveEpilogue<
|
||||
Sm90PtrArrayTmaWarpSpecialized<StagesC_,StagesD_,FragmentSize_,ReuseSmemC_,DelayTmaStore_>,
|
||||
Sm90PtrArrayTmaWarpSpecialized<StagesC_,
|
||||
StagesD_,
|
||||
FragmentSize_,
|
||||
ReuseSmemC_,
|
||||
DelayTmaStore_,
|
||||
NumEpilogueWarpGroups_
|
||||
>,
|
||||
CtaTileMNK_,
|
||||
EpilogueTile_,
|
||||
ElementC_,
|
||||
@ -98,7 +106,13 @@ public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized<StagesC_,StagesD_,FragmentSize_,ReuseSmemC_,DelayTmaStore_>;
|
||||
using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized<StagesC_,
|
||||
StagesD_,
|
||||
FragmentSize_,
|
||||
ReuseSmemC_,
|
||||
DelayTmaStore_,
|
||||
NumEpilogueWarpGroups_
|
||||
>;
|
||||
using CtaTileMNK = CtaTileMNK_;
|
||||
using EpilogueTile = EpilogueTile_;
|
||||
using FusionCallbacks = FusionCallbacks_;
|
||||
@ -201,6 +215,8 @@ public:
|
||||
(size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof_bits<SmemElementC>::value)) / 8;
|
||||
constexpr static bool RequiresTransactionBytes = true;
|
||||
|
||||
constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_;
|
||||
|
||||
// TMA pipeline for storing D
|
||||
using StorePipeline = cute::conditional_t<ReuseSmemC,
|
||||
cutlass::PipelineTmaStore<StagesC, StagesD-1>,
|
||||
@ -219,7 +235,7 @@ public:
|
||||
|
||||
struct TensorMapStorage : cute::aligned_struct<128> {
|
||||
cute::TmaDescriptor smem_tensormap_C;
|
||||
cute::TmaDescriptor smem_tensormap_D;
|
||||
cute::array<cute::TmaDescriptor, NumEpilogueWarpGroups> smem_tensormap_D;
|
||||
} tensormaps;
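Each `cute::TmaDescriptor` wraps a 128-byte CUtensorMap, so the shared-memory cost of this change is easy to bound; an illustrative check assuming the pingpong schedule (two epilogue warp groups):

// With NumEpilogueWarpGroups == 2: one C descriptor + two D descriptors = 3 * 128 = 384 bytes,
// aligned to 128 bytes by cute::aligned_struct<128> above.
static_assert(sizeof(cute::TmaDescriptor) == 128, "TMA descriptor is a 128B CUtensorMap");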
|
||||
|
||||
using PipelineStorage = typename LoadPipeline::SharedStorage;
|
||||
@ -229,6 +245,8 @@ public:
|
||||
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideC, StrideC>;
|
||||
|
||||
// Host side epilogue arguments
|
||||
struct Arguments {
|
||||
typename FusionCallbacks::Arguments thread{};
|
||||
@ -261,7 +279,9 @@ public:
|
||||
TMA_D tma_store_d;
|
||||
cute::TmaDescriptor* tensormaps;
|
||||
ElementC const** ptr_C;
|
||||
StrideC dC;
|
||||
ElementD** ptr_D;
|
||||
StrideD dD;
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
||||
};
|
||||
|
||||
@ -275,36 +295,57 @@ public:
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
[[maybe_unused]] void* workspace) {
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1);
|
||||
auto [M, N, K, mock_L] = problem_shape_MNKL;
|
||||
// Manage batches/groups through pointers to input matrices
|
||||
mock_L = 1;
|
||||
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
|
||||
// These will be replaced with correct values before the initial tma load.
|
||||
auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
|
||||
auto init_M = get<0>(init_shape);
|
||||
auto init_N = get<1>(init_shape);
|
||||
auto init_L = get<3>(init_shape);
|
||||
|
||||
static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D");
|
||||
|
||||
InternalStrideC stride_c;
|
||||
InternalStrideD stride_d;
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
|
||||
stride_c = InternalStrideC{};
|
||||
stride_d = InternalStrideD{};
|
||||
}
|
||||
else {
|
||||
// Tensor shapes for Ptr-Array are initialized correctly only here.
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1);
|
||||
init_M = get<0>(problem_shape_MNKL);
|
||||
init_N = get<1>(problem_shape_MNKL);
|
||||
init_L = get<3>(problem_shape_MNKL);
|
||||
|
||||
stride_c = args.dC;
|
||||
stride_d = args.dD;
|
||||
}
|
||||
|
||||
uint32_t transaction_bytes = TmaTransactionBytes;
|
||||
typename Params::TMA_C tma_load_c = {};
|
||||
if constexpr (is_source_supported) {
|
||||
ElementC const* ptr_C_first_batch = reinterpret_cast<ElementC const*>(args.ptr_C);
|
||||
Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dC, _0{})));
|
||||
tma_load_c = make_tma_copy_C_sm90(
|
||||
Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{})));
|
||||
tma_load_c = make_tma_copy(
|
||||
CopyOpG2S{},
|
||||
tensor_c,
|
||||
take<0,2>(SmemLayoutC{}),
|
||||
EpilogueTile{});
|
||||
EpilogueTile{},
|
||||
_1{});
|
||||
|
||||
}
|
||||
|
||||
typename Params::TMA_D tma_store_d;
|
||||
if constexpr (is_destination_supported) {
|
||||
ElementD const* ptr_D_first_batch = reinterpret_cast<ElementD const*>(args.ptr_D);
|
||||
Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dD, _0{})));
|
||||
tma_store_d = make_tma_copy_C_sm90(
|
||||
Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{})));
|
||||
tma_store_d = make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
tensor_d,
|
||||
take<0,2>(SmemLayoutD{}),
|
||||
EpilogueTile{});
|
||||
EpilogueTile{},
|
||||
_1{});
|
||||
}
|
||||
|
||||
auto fusion_workspace = static_cast<char*>(workspace);
|
||||
@ -318,7 +359,9 @@ public:
|
||||
tma_store_d,
|
||||
tma_descriptor_workspace,
|
||||
args.ptr_C,
|
||||
args.dC,
|
||||
args.ptr_D,
|
||||
args.dD,
|
||||
transaction_bytes,
|
||||
};
|
||||
}
|
||||
@ -326,10 +369,11 @@ public:
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
|
||||
constexpr uint32_t NumInputTensors = cute::is_void_v<ElementC> ? 1 : 2;
|
||||
constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v<ElementC> ? 0 : 1);
|
||||
auto descriptors_shape = cute::make_shape(sm_count, Int<NumInputTensors>{});
|
||||
constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
|
||||
// Allocate gmem space for input tensormaps per SM, A tensormap copies followed by B tensormap copies
|
||||
return (NumInputTensors * SizeOfCuTensorMap * sm_count) + FusionCallbacks::get_workspace_size(problem_shape, args.thread);
|
||||
return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread);
|
||||
}
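As a worked, hypothetical sizing of the descriptor workspace computed above (standalone illustration, not part of the collective):

// Grouped GEMM with a real C tensor and NumEpilogueWarpGroups == 2 (pingpong):
// 2 D descriptors + 1 C descriptor per SM, on a GPU with 132 SMs.
constexpr size_t NumInputTensors = 2 + 1;
constexpr size_t sm_count = 132;
constexpr size_t tensormap_bytes = NumInputTensors * sm_count * sizeof(cute::TmaDescriptor);
static_assert(tensormap_bytes == 3 * 132 * 128, "descriptor region only, excluding the fusion workspace");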
|
||||
|
||||
template <class ProblemShape>
|
||||
@ -342,30 +386,40 @@ public:
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
ProblemShape problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
if constexpr (is_destination_supported) {
|
||||
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
|
||||
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
|
||||
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), InternalStrideD{});
|
||||
}
|
||||
bool fusion_implementable = true;
|
||||
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
|
||||
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(cute::make_shape(M,N,L), InternalStrideC{});
|
||||
if (problem_shape.is_host_problem_shape_available()) {
|
||||
for (int i = 0; i < problem_shape.groups(); ++i) {
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
if constexpr (is_destination_supported) {
|
||||
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
|
||||
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), InternalStrideD{});
|
||||
}
|
||||
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
|
||||
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(cute::make_shape(M,N,L), InternalStrideC{});
|
||||
}
|
||||
|
||||
fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread);
|
||||
}
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n");
|
||||
}
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
|
||||
bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread);
|
||||
|
||||
if (!fusion_implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
|
||||
}
|
||||
@ -414,10 +468,14 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto
|
||||
load_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const {
|
||||
load_init(
|
||||
Params const& params,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
int32_t sm_count,
|
||||
int32_t sm_idx) {
|
||||
// Initialize tma for loading
|
||||
constexpr bool IsLoad = true;
|
||||
auto load_tensormaps = tensormaps_init<IsLoad>(params, sm_count, sm_idx);
|
||||
auto load_tensormaps = tensormaps_init<IsLoad>(params, shared_tensormaps, sm_count, sm_idx, 0);
|
||||
return load_tensormaps;
|
||||
}
|
||||
|
||||
@ -426,7 +484,8 @@ public:
|
||||
class TileShapeMNK,
|
||||
class TileCoordMNKL,
|
||||
class TiledMma,
|
||||
class TensorMapC
|
||||
class TensorMapC,
|
||||
__CUTE_REQUIRES(std::is_pointer_v<TensorMapC>)
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
load(
|
||||
@ -440,7 +499,7 @@ public:
|
||||
TensorStorage& shared_tensors,
|
||||
TensorMapC const& load_tensormap,
|
||||
int subtile_idx=-1,
|
||||
bool return_prior_state = false) {
|
||||
bool wait_until_load_finishes = false) {
|
||||
using namespace cute;
|
||||
|
||||
// Indexing variables
|
||||
@ -478,17 +537,21 @@ public:
|
||||
auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args);
|
||||
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
|
||||
|
||||
LoadPipelineState last_load_producer_state = load_pipe_producer_state;
|
||||
|
||||
// Predication for TMA load (one thread issues TMA load)
|
||||
bool issue_tma_load = cute::elect_one_sync();
|
||||
|
||||
// Acquire the lock for the first stage
|
||||
uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
|
||||
load_pipeline.producer_acquire(load_pipe_producer_state);
|
||||
uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
|
||||
|
||||
// Pre-loop fusion callback entry point
|
||||
pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load);
|
||||
|
||||
auto prior_state = load_pipe_producer_state;
|
||||
LoadPipelineState prior_state = load_pipe_producer_state;
|
||||
|
||||
bool did_load = false;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) {
|
||||
@ -506,15 +569,18 @@ public:
|
||||
pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load);
|
||||
|
||||
// Execute the TMA load for C if needed
|
||||
if (issue_tma_load && is_C_load_needed) {
|
||||
copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask),
|
||||
bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index()));
|
||||
load_pipeline.producer_expect_transaction(load_pipe_producer_state);
|
||||
if (is_C_load_needed) {
|
||||
if (issue_tma_load) {
|
||||
copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask),
|
||||
bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index()));
|
||||
load_pipeline.producer_expect_transaction(load_pipe_producer_state);
|
||||
}
|
||||
last_load_producer_state = load_pipe_producer_state;
|
||||
did_load = true;
|
||||
}
|
||||
|
||||
// Commit TMA loads for this stage and release the lock
|
||||
load_pipeline.producer_commit(load_pipe_producer_state);
|
||||
prior_state = load_pipe_producer_state;
|
||||
++load_pipe_producer_state;
|
||||
}
|
||||
}
|
||||
@ -522,17 +588,24 @@ public:
|
||||
// Post-loop fusion callback entry point
|
||||
pld_callbacks.end();
|
||||
|
||||
if (not return_prior_state) {
|
||||
return load_pipe_producer_state;
|
||||
} else {
|
||||
return prior_state;
|
||||
if (wait_until_load_finishes && did_load) {
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
|
||||
{last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()};
|
||||
load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
|
||||
}
|
||||
|
||||
return load_pipe_producer_state;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto
|
||||
load_tail(
|
||||
LoadPipeline load_pipeline,
|
||||
LoadPipelineState load_pipe_producer_state) {
|
||||
|
||||
if (!fusion_callbacks.is_producer_load_needed()) {
|
||||
return load_pipe_producer_state;
|
||||
}
|
||||
|
||||
bool issue_tma_load = cute::elect_one_sync();
|
||||
if (issue_tma_load) {
|
||||
load_pipeline.producer_tail(load_pipe_producer_state);
|
||||
@ -564,6 +637,7 @@ public:
|
||||
TensorStorage& shared_tensors,
|
||||
TensorMapD const& store_tensormap,
|
||||
int subtile_idx=-1) {
|
||||
|
||||
using namespace cute;
|
||||
using ElementAccumulator = typename AccEngine::value_type;
|
||||
using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::ElementCompute;
|
||||
@ -869,11 +943,22 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto
|
||||
store_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const {
|
||||
// Initialize tma
|
||||
constexpr bool IsLoad = false;
|
||||
auto store_tensormaps = tensormaps_init<IsLoad>(params, sm_count, sm_idx);
|
||||
return store_tensormaps;
|
||||
store_init(
|
||||
Params const& params,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
int32_t sm_count,
|
||||
int32_t sm_idx,
|
||||
int32_t warp_group_idx) {
|
||||
int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup;
|
||||
// Since only one warp issues TMA store, we only need that one warp to initialize tensormaps
|
||||
if (warp_idx_in_warp_group == 0) {
|
||||
// Initialize tma
|
||||
constexpr bool IsLoad = false;
|
||||
auto store_tensormaps = tensormaps_init<IsLoad>(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx);
|
||||
return store_tensormaps;
|
||||
}
|
||||
TmaDescriptor* null_tma_desc = nullptr;
|
||||
return cute::make_tuple(null_tma_desc);
|
||||
}
|
||||
|
||||
//
|
||||
@ -882,53 +967,45 @@ public:
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE auto
|
||||
tensormaps_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const {
|
||||
cute::TmaDescriptor* tma_desc = nullptr;
|
||||
cute::TmaDescriptor* gmem_tensormap = params.tensormaps;
|
||||
tensormaps_init(
|
||||
Params const& params,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
int32_t sm_count,
|
||||
int32_t sm_idx,
|
||||
int32_t warp_group_idx) {
|
||||
|
||||
constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v<ElementC> ? 0 : 1);
|
||||
Layout desc_layout = make_layout(make_shape(sm_count, Int<NumInputTensors>{}));
|
||||
|
||||
Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors)
|
||||
|
||||
if constexpr (IsLoad) {
|
||||
if (not cute::is_void_v<ElementC>) {
|
||||
tma_desc = &gmem_tensormap[sm_idx];
|
||||
constexpr int C_tensormap_index = NumEpilogueWarpGroups;
|
||||
Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{});
|
||||
Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{});
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
// Bringing tensormaps from params to gmem for modification later
|
||||
Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{});
|
||||
Tensor gC_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{});
|
||||
copy(recast<uint128_t>(pC_tensormap), recast<uint128_t>(gC_tensormap));
|
||||
// Bringing tensormaps from params to smem for modification later
|
||||
copy(recast<uint128_t>(pC_tensormap), recast<uint128_t>(sC_tensormap));
|
||||
}
|
||||
__syncwarp();
|
||||
return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index));
|
||||
}
|
||||
} else {
|
||||
int const offset_Ddesc = cute::is_void_v<ElementC> ? 0 : sm_count;
|
||||
tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc];
|
||||
TmaDescriptor* null_tma_desc = nullptr;
|
||||
return cute::make_tuple(null_tma_desc);
|
||||
}
|
||||
else {
|
||||
Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{});
|
||||
Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{});
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
// Bringing tensormaps from params to gmem for modification later
|
||||
Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{});
|
||||
Tensor gD_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{});
|
||||
copy(recast<uint128_t>(pD_tensormap), recast<uint128_t>(gD_tensormap));
|
||||
// Bringing tensormaps from params to smem for modification later
|
||||
copy(recast<uint128_t>(pD_tensormap), recast<uint128_t>(sD_tensormap));
|
||||
}
|
||||
__syncwarp();
|
||||
return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx));
|
||||
}
|
||||
|
||||
return cute::make_tuple(tma_desc);
|
||||
}
|
||||
|
||||
// Bringing tensormaps to smem (to be done by single thread)
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_fetch_to_smem(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
cute::TmaDescriptor const* tensormap) const {
|
||||
if constexpr (IsLoad) {
|
||||
if (not cute::is_void_v<ElementC>) {
|
||||
Tensor gC_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{});
|
||||
Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{});
|
||||
copy(recast<uint128_t>(gC_tensormap), recast<uint128_t>(sC_tensormap));
|
||||
}
|
||||
} else {
|
||||
Tensor gD_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{});
|
||||
Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{});
|
||||
copy(recast<uint128_t>(gD_tensormap), recast<uint128_t>(sD_tensormap));
|
||||
}
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
}
|
||||
|
||||
// Replace address for the global tensor (to be done by single thread)
|
||||
@ -936,35 +1013,99 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_replace_global_address(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
Params const& params,
|
||||
int32_t next_batch) {
|
||||
int32_t next_batch,
|
||||
int32_t warp_group_idx) {
|
||||
// Replacing global_address for the next batch
|
||||
if constexpr (IsLoad) {
|
||||
if (not cute::is_void_v<ElementC>) {
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C,
|
||||
if constexpr (is_source_supported) {
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C,
|
||||
params.ptr_C[next_batch]);
|
||||
}
|
||||
} else {
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D,
|
||||
}
|
||||
else if constexpr (is_destination_supported) {
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx],
|
||||
params.ptr_D[next_batch]);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool IsLoad>
|
||||
// Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
|
||||
template <bool IsLoad, class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_replace_global_tensor_properties(
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
Params const& params,
|
||||
int32_t next_group,
|
||||
ProblemShape_MNKL problem_shape_mnkl,
|
||||
int32_t warp_group_idx) {
|
||||
const uint32_t M = get<0>(problem_shape_mnkl);
|
||||
const uint32_t N = get<1>(problem_shape_mnkl);
|
||||
// Only consider dimensions and strides that we need to recalculate and replace for each group
|
||||
constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
|
||||
static_assert(TensorRank == Int<3>{},
|
||||
"Descriptor modification for global dims & strides expects rank as 3.");
|
||||
|
||||
cute::array<uint32_t, TensorRank> prob_shape = {1,1,1};
|
||||
cute::array<uint64_t, TensorRank> prob_stride = {0,0,0};
|
||||
|
||||
if constexpr (IsLoad) {
|
||||
if constexpr (is_source_supported) {
|
||||
ElementC const* ptr_C = nullptr;
|
||||
Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group]));
|
||||
|
||||
cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c,
|
||||
prob_shape, prob_stride);
|
||||
// Convert strides to byte strides
|
||||
for (uint64_t& stride : prob_stride) {
|
||||
stride = (stride * sizeof_bits_v<ElementC>) / 8;
|
||||
}
|
||||
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C,
|
||||
prob_shape,
|
||||
prob_stride);
|
||||
}
|
||||
}
|
||||
else if constexpr (is_destination_supported) {
|
||||
ElementD const* ptr_D = nullptr;
|
||||
|
||||
// tma_store_c should be a gmem_tensor, second argument should be a stride
|
||||
|
||||
Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group]));
|
||||
|
||||
cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d,
|
||||
prob_shape, prob_stride);
|
||||
// Convert strides to byte strides
|
||||
for (uint64_t& stride : prob_stride) {
|
||||
stride = (stride * sizeof_bits_v<ElementD>) / 8;
|
||||
}
|
||||
|
||||
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx],
|
||||
prob_shape,
|
||||
prob_stride);
|
||||
}
|
||||
}
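A small worked example of the byte-stride conversion performed above (standalone illustration; the element type and stride value are assumptions):

// For ElementD = cutlass::half_t (16 bits), an element stride of 4096 becomes 8192 bytes.
constexpr uint64_t stride_elems = 4096;
constexpr uint64_t stride_bytes = (stride_elems * cutlass::sizeof_bits<cutlass::half_t>::value) / 8;
static_assert(stride_bytes == 8192, "half_t occupies 2 bytes per element");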
|
||||
|
||||
template <bool IsLoad, class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_perform_update(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
Params const& params,
|
||||
cute::TmaDescriptor const* tensormap,
|
||||
int32_t next_batch) {
|
||||
ProblemShape_MNKL problem_shape_mnkl,
|
||||
int32_t next_batch,
|
||||
int32_t warp_group_idx) {
|
||||
if (cute::elect_one_sync()) {
|
||||
// Bringing tensormaps to smem
|
||||
tensormaps_fetch_to_smem<IsLoad>(shared_tensormap, tensormap);
|
||||
|
||||
// Replacing global_address for the next batch
|
||||
tensormaps_replace_global_address<IsLoad>(shared_tensormap, params, next_batch);
|
||||
tensormaps_replace_global_address<IsLoad>(shared_tensormaps, params, next_batch, warp_group_idx);
|
||||
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
// Replacing global dims and strides for the next batch
|
||||
tensormaps_replace_global_tensor_properties<IsLoad>(
|
||||
shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -972,16 +1113,18 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_cp_fence_release(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
cute::TmaDescriptor const* tensormap,
|
||||
[[maybe_unused]] uint32_t lane_predicate) {
|
||||
[[maybe_unused]] uint32_t lane_predicate,
|
||||
int32_t warp_group_idx = 0) {
|
||||
// Entire warp must do this (i.e., it's aligned)
|
||||
if constexpr (IsLoad) {
|
||||
if (not cute::is_void_v<ElementC>) {
|
||||
tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C);
|
||||
if constexpr (is_source_supported) {
|
||||
tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C);
|
||||
}
|
||||
} else {
|
||||
tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D);
|
||||
}
|
||||
else if constexpr (is_destination_supported) {
|
||||
tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -990,10 +1133,11 @@ public:
|
||||
void
|
||||
tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) {
|
||||
if constexpr (IsLoad) {
|
||||
if (not cute::is_void_v<ElementC>) {
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
cute::tma_descriptor_fence_acquire(tensormap);
|
||||
}
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
cute::tma_descriptor_fence_acquire(tensormap);
|
||||
}
|
||||
}
|
||||
|
||||
@ -51,7 +51,21 @@ struct PtrArrayNoSmemWarpSpecialized {};
|
||||
struct PtrArrayPlanarComplexNoSmemWarpSpecialized {};
|
||||
struct TmaWarpSpecialized {};
|
||||
struct TmaWarpSpecializedCooperative {};
|
||||
struct PtrArrayTmaWarpSpecializedCooperative {};
|
||||
|
||||
struct PtrArrayTmaWarpSpecializedCooperative {
|
||||
static constexpr int NumEpilogueWarpGroups = 2;
|
||||
};
|
||||
|
||||
// Standard warp specialized epilogue
|
||||
struct PtrArrayTmaWarpSpecialized {
|
||||
static constexpr int NumEpilogueWarpGroups = 1;
|
||||
};
|
||||
|
||||
// Pingpong kernel epilogue
|
||||
struct PtrArrayTmaWarpSpecializedPingpong {
|
||||
static constexpr int NumEpilogueWarpGroups = 2;
|
||||
};
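The collective builder and the ptr-array epilogue read `NumEpilogueWarpGroups` directly off these schedule tags, as declared above:

static_assert(PtrArrayTmaWarpSpecialized::NumEpilogueWarpGroups == 1, "");
static_assert(PtrArrayTmaWarpSpecializedCooperative::NumEpilogueWarpGroups == 2, "");
static_assert(PtrArrayTmaWarpSpecializedPingpong::NumEpilogueWarpGroups == 2, "");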
|
||||
|
||||
// DEPRECATED schedules, will be removed in next release
|
||||
struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {};
|
||||
struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {};
|
||||
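The schedule tags now carry NumEpilogueWarpGroups, which lets the kernel check that the epilogue configuration matches the number of MMA warp groups (the cooperative kernel later in this diff adds exactly such a static_assert). A standalone sketch using local stand-in tags:

struct PtrArrayTmaWarpSpecializedPingpongTag { static constexpr int NumEpilogueWarpGroups = 2; };
struct PtrArrayTmaWarpSpecializedTag         { static constexpr int NumEpilogueWarpGroups = 1; };

template <class EpilogueSchedule, int NumMmaWarpGroups>
constexpr bool schedule_matches_mainloop() {
  return EpilogueSchedule::NumEpilogueWarpGroups == NumMmaWarpGroups;
}

static_assert(schedule_matches_mainloop<PtrArrayTmaWarpSpecializedPingpongTag, 2>(),
              "Tiled MMA does not match expected warp groups performing the epilogue");
static_assert(schedule_matches_mainloop<PtrArrayTmaWarpSpecializedTag, 1>(), "");

int main() { return 0; }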
@ -151,7 +165,8 @@ template<
  int StagesD_,
  int FragmentSize_,
  bool ReuseSmemC_,
  bool DelayTmaStore_
  bool DelayTmaStore_,
  int NumEpilogueWarpGroups_
>
struct Sm90PtrArrayTmaWarpSpecialized {
  constexpr static int StagesC = StagesC_;
@ -159,6 +174,7 @@ struct Sm90PtrArrayTmaWarpSpecialized {
  constexpr static int FragmentSize = FragmentSize_;
  constexpr static bool ReuseSmemC = ReuseSmemC_;
  constexpr static bool DelayTmaStore = DelayTmaStore_;
  constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_;
};

// DEPRECATED policies, will be removed in next release
@ -32,6 +32,7 @@
#pragma once

#include <cutlass/numeric_conversion.h>
#include <cutlass/layout/matrix.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -179,6 +179,89 @@ struct FusionCallbacks<

/////////////////////////////////////////////////////////////////////////////////////////////////

// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch
template<
  class ElementOutput,
  class ElementCompute,
  class ElementSource = ElementOutput,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinearCombinationPtrArray =
  Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
    Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int>>, // beta
    Sm90SrcFetch<ElementSource>, // C
    Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
      Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int>>, // alpha
      Sm90AccFetch // acc
    >
  >;
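The visitor tree above evaluates D = beta * C + (alpha * acc) per element: the inner Sm90EVT node computes alpha * acc, and the outer homogeneous_multiply_add node folds in beta * C. A scalar model of the same dataflow with plain floats:

#include <cmath>
#include <cstdio>

int main() {
  float alpha = 1.5f, beta = 0.5f;   // per-group scalars (or fetched through the PtrArray broadcasts)
  float acc   = 2.0f;                // accumulator element
  float c     = 4.0f;                // source element
  float inner = alpha * acc;               // Sm90Compute<multiplies>
  float d     = std::fma(beta, c, inner);  // Sm90Compute<homogeneous_multiply_add>
  std::printf("D = %f\n", d);              // 0.5*4 + 1.5*2 = 5
  return 0;
}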
template <
  int StagesC,
  int StagesD,
  int FragmentSize,
  bool ReuseSmemC,
  bool DelayTmaStore,
  int NumEpilogueWarpGroups,
  class ElementOutput,
  class ElementCompute,
  class ElementSource,
  class ElementScalar,
  FloatRoundStyle RoundStyle,
  class CtaTileShapeMNK,
  class EpilogueTile
>
struct FusionCallbacks<
    epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC,
                                             StagesD,
                                             FragmentSize,
                                             ReuseSmemC,
                                             DelayTmaStore,
                                             NumEpilogueWarpGroups
                                            >,
    fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>,
    CtaTileShapeMNK,
    EpilogueTile
> : Sm90LinearCombinationPtrArray<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> {

  using Impl = Sm90LinearCombinationPtrArray<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
  using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>;

  struct Arguments {
    ElementScalar alpha = ElementScalar(1);
    ElementScalar beta = ElementScalar(0);
    ElementScalar const* alpha_ptr = nullptr;
    ElementScalar const* beta_ptr = nullptr;
    ElementScalar const* const* alpha_ptr_array = nullptr;
    ElementScalar const* const* beta_ptr_array = nullptr;

    using StrideAlpha = Stride<_0,_0,int>;
    using StrideBeta = Stride<_0,_0,int>;
    StrideAlpha dAlpha = {_0{}, _0{}, 0};
    StrideBeta dBeta = {_0{}, _0{}, 0};

    operator typename Impl::Arguments() const {
      return
        { // ternary op : beta * C + (alpha * acc)
          {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta
          {},                                              // leaf args : C
          { // binary op : alpha * acc
            {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha
            {},                                                  // leaf args : acc
            {}                                                   // binary args : multiplies
          },                                               // end binary op
          {}                                               // ternary args : multiply_add
        };                                                 // end ternary op
    }
  };

  // Ctor inheritance
  using Impl::Impl;
};
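A hedged host-side sketch of filling the new alpha_ptr_array / beta_ptr_array fields so each group reads its own scalars; FusionArgsModel is a float stand-in for the Arguments struct above and the variable names are illustrative. Setting the int component of dAlpha/dBeta to 1 makes group l read entry l, while 0 makes every group share entry 0 (see the l_offset computation in the broadcast node below):

#include <vector>

struct FusionArgsModel {                               // stand-in for Arguments above
  float alpha = 1.0f, beta = 0.0f;
  float const* alpha_ptr = nullptr;
  float const* beta_ptr  = nullptr;
  float const* const* alpha_ptr_array = nullptr;
  float const* const* beta_ptr_array  = nullptr;
  int dAlpha_batch = 0, dBeta_batch = 0;               // models the `int` in Stride<_0,_0,int>
};

int main() {
  int num_groups = 3;
  std::vector<float> alphas(num_groups, 1.0f), betas(num_groups, 0.0f);
  std::vector<float const*> alpha_ptrs(num_groups), beta_ptrs(num_groups);
  for (int l = 0; l < num_groups; ++l) {               // one device pointer per group in practice
    alpha_ptrs[l] = &alphas[l];
    beta_ptrs[l]  = &betas[l];
  }
  FusionArgsModel fusion_args;
  fusion_args.alpha_ptr_array = alpha_ptrs.data();
  fusion_args.beta_ptr_array  = beta_ptrs.data();
  fusion_args.dAlpha_batch = 1;                        // group l reads alpha_ptr_array[l]
  fusion_args.dBeta_batch  = 1;
  return 0;
}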
/////////////////////////////////////////////////////////////////////////////////////////////////

// D = activation(alpha * acc + beta * C)
template<
  template <class> class ActivationFn,

@ -37,6 +37,7 @@

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cutlass/epilogue/collective/detail.hpp"

#include "cute/tensor.hpp"
#include "sm90_visitor_tma_warpspecialized.hpp"
@ -514,7 +515,8 @@ private:

    if (params_ptr->scalar_ptrs[0] != nullptr) {
      scalar = params_ptr->scalar_ptrs[0][l_offset];
    } else {
    }
    else {
      // batch stride is ignored for nullptr fallback
      scalar = params_ptr->scalars[0];
    }
@ -541,6 +543,169 @@ private:
  }
};

// Scalar broadcast
|
||||
// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors
|
||||
template<
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_0,_0>,
|
||||
int BroadcastCount = 1,
|
||||
template <class> class ReductionFn = multiplies
|
||||
>
|
||||
struct Sm90ScalarBroadcastPtrArray {
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{});
|
||||
|
||||
struct SharedStorage { };
|
||||
|
||||
struct Arguments {
|
||||
Element scalars[BroadcastCount] = {};
|
||||
Element const* scalar_ptrs[BroadcastCount] = {};
|
||||
Element const* const* scalar_ptr_arrays[BroadcastCount] = {};
|
||||
StrideMNL dScalar[BroadcastCount] = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
// producer load is needed if Element is not void and we have multiple scalars
|
||||
return !cute::is_void_v<Element> and size<2>(params_ptr->dScalar[0]) != 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// This must be called after update_scalar is called
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return scalar == Element(0);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ScalarBroadcastPtrArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) {
|
||||
// Get the scalar for non-batched broadcast
|
||||
if (size<2>(params_ptr->dScalar[0]) == 0) {
|
||||
update_scalar();
|
||||
}
|
||||
}
|
||||
|
||||
Element scalar;
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
// Get the scalar for batched broadcast
|
||||
if (get<2>(params_ptr->dScalar[0]) != 0) {
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl;
|
||||
update_scalar(l_coord);
|
||||
}
|
||||
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(Element scalar)
|
||||
: scalar(scalar) {}
|
||||
|
||||
Element scalar;
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_scalar;
|
||||
frg_scalar.fill(scalar);
|
||||
|
||||
return frg_scalar;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
// Get the scalar for batched broadcast
|
||||
if (get<2>(params_ptr->dScalar[0]) != 0) {
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl;
|
||||
update_scalar(l_coord);
|
||||
}
|
||||
|
||||
return ConsumerStoreCallbacks(scalar);
|
||||
}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE void
|
||||
update_scalar(int l_coord = 0) {
|
||||
int l_offset = l_coord * size<2>(params_ptr->dScalar[0]);
|
||||
|
||||
if (params_ptr->scalar_ptr_arrays[0] != nullptr) {
|
||||
scalar = *(params_ptr->scalar_ptr_arrays[0][l_offset]);
|
||||
}
|
||||
else if (params_ptr->scalar_ptrs[0] != nullptr) {
|
||||
scalar = params_ptr->scalar_ptrs[0][l_offset];
|
||||
}
|
||||
else {
|
||||
// batch stride is ignored for nullptr fallback
|
||||
scalar = params_ptr->scalars[0];
|
||||
}
|
||||
|
||||
// Do reduction over multiple broadcasts if necessary
|
||||
ReductionFn<Element> reduction_fn;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 1; i < BroadcastCount; ++i) {
|
||||
|
||||
if (params_ptr->scalar_ptr_arrays[i] != nullptr) {
|
||||
int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
|
||||
scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][rest_l_offset]));
|
||||
}
|
||||
if (params_ptr->scalar_ptrs[i] != nullptr) {
|
||||
int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
|
||||
scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]);
|
||||
}
|
||||
else {
|
||||
// batch stride is ignored for nullptr fallback
|
||||
scalar = reduction_fn(scalar, params_ptr->scalars[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
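A standalone model of the resolution order implemented by update_scalar above: a per-group pointer array takes precedence over a single device pointer, which takes precedence over the immediate value (whose batch stride is ignored):

#include <cassert>

float resolve_scalar(float const* const* scalar_ptr_array,
                     float const* scalar_ptr,
                     float scalar_immediate,
                     int l_coord, int batch_stride) {
  int l_offset = l_coord * batch_stride;               // mirrors l_coord * size<2>(dScalar[0])
  if (scalar_ptr_array != nullptr) {
    return *scalar_ptr_array[l_offset];                // per-group pointer (e.g. one alpha per group)
  }
  else if (scalar_ptr != nullptr) {
    return scalar_ptr[l_offset];                       // single array indexed by batch
  }
  else {
    return scalar_immediate;                           // fallback; batch stride ignored
  }
}

int main() {
  float a0 = 2.0f, a1 = 3.0f;
  float const* per_group[] = {&a0, &a1};
  assert(resolve_scalar(per_group, nullptr, 1.0f, /*l_coord=*/1, /*batch_stride=*/1) == 3.0f);
  assert(resolve_scalar(nullptr,   nullptr, 1.0f, 1, 1) == 1.0f);
  return 0;
}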
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
@ -31,6 +31,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/pipeline/sm90_pipeline.hpp"
|
||||
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
|
||||
|
||||
// SM90 Collective Builders should be used only starting CUDA 12.0
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
@ -177,10 +181,12 @@ struct CollectiveBuilder<
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &&
|
||||
(cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecialized,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelTmaWarpSpecializedPingpong,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>) &&
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
@ -191,10 +197,12 @@ struct CollectiveBuilder<
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
|
||||
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
|
||||
"KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
@ -203,8 +211,10 @@ struct CollectiveBuilder<
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> || IsArrayOfPointersGemm,
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative>;
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
@ -218,7 +228,10 @@
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
      MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
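A sketch of the carveout arithmetic, assuming sizeof(cute::TmaDescriptor) is the 128-byte CUtensorMap and detail::sm90_smem_capacity_bytes is CUTLASS's 232448-byte SM90 budget: ptr-array/grouped kernels reserve two descriptors (A and B) in shared memory and compute the stage count against the reduced capacity.

#include <cstdio>

int main() {
  constexpr bool is_array_of_pointers_gemm = true;
  constexpr int  sm90_smem_capacity_bytes  = 232448;   // assumed value of detail::sm90_smem_capacity_bytes
  constexpr int  tma_descriptor_bytes      = 128;      // assumed sizeof(cute::TmaDescriptor)
  constexpr int  kernel_smem_carveout =
      is_array_of_pointers_gemm ? 2 * tma_descriptor_bytes : 0;   // A and B descriptors
  std::printf("smem available for mainloop stages: %d bytes\n",
              sm90_smem_capacity_bytes - kernel_smem_carveout);
  return 0;
}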
@ -505,10 +518,12 @@ struct CollectiveBuilder<
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>
|
||||
cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedFP8FastAccum,
|
||||
KernelTmaWarpSpecializedPingpongFP8FastAccum,
|
||||
KernelTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum>>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
@ -526,10 +541,15 @@ struct CollectiveBuilder<
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
|
||||
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum>;
|
||||
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>;
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative, Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
@ -542,7 +562,11 @@ struct CollectiveBuilder<
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
|
||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<Sm90ReducedSmemCapacityBytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
|
||||
@ -623,56 +623,42 @@ struct CollectiveMma<
  //

  CUTLASS_DEVICE auto
  tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const {
  tensormaps_init(
      Params const& mainloop_params,
      TensorMapStorage& shared_tensormaps,
      int32_t sm_count,
      int32_t sm_idx) {
    cute::TmaDescriptor* gmem_tensormap = reinterpret_cast<cute::TmaDescriptor*>(mainloop_params.tensormaps);

    cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
    cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];

    if (cute::elect_one_sync()) {
      // Bringing tensormaps from params to gmem for modification later
      // Bringing tensormaps from params to smem for modification later
      Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{});
      Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{});
      Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
      Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{});
      Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{});
      Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});

      copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(gA_tensormap));
      copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(gB_tensormap));
      copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
      copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));
    }
    __syncwarp();

    return cute::make_tuple(tma_desc_a, tma_desc_b);
  }

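A loose host-side model of the staging step above (stand-in types, not the CuTe implementation): the descriptor held in the params is copied into its shared-memory slot in 16-byte chunks, mirroring copy() over tensors recast to uint128_t, before individual fields are patched for the next group.

#include <array>
#include <cstdint>
#include <cstring>

struct DescriptorModel { alignas(128) std::array<std::uint8_t, 128> bytes{}; };   // assumed 128-byte descriptor

void stage_descriptor(DescriptorModel const& params_copy, DescriptorModel& smem_slot) {
  for (int chunk = 0; chunk < 8; ++chunk) {            // 8 x 16 B = 128 B, like recast<uint128_t>
    std::memcpy(smem_slot.bytes.data() + 16 * chunk,
                params_copy.bytes.data() + 16 * chunk, 16);
  }
}

int main() {
  DescriptorModel a_params{}, a_smem{};
  a_params.bytes[0] = 0x42;
  stage_descriptor(a_params, a_smem);
  return a_smem.bytes[0] == 0x42 ? 0 : 1;
}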
// Bringing tensormaps to smem (to be done by single thread)
|
||||
template <class TensorMapA, class TensorMapB>
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_fetch_to_smem(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) const {
|
||||
Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{});
|
||||
Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{});
|
||||
Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{});
|
||||
Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{});
|
||||
|
||||
copy(recast<uint128_t>(gA_tensormap), recast<uint128_t>(sA_tensormap));
|
||||
copy(recast<uint128_t>(gB_tensormap), recast<uint128_t>(sB_tensormap));
|
||||
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
}
|
||||
|
||||
// Replace address for the global tensor (to be done by single thread)
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_replace_global_address(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
Params const& mainloop_params,
|
||||
int32_t next_batch) {
|
||||
// Replacing global_address for the next batch
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A,
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A,
|
||||
mainloop_params.ptr_A[next_batch]);
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B,
|
||||
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B,
|
||||
mainloop_params.ptr_B[next_batch]);
|
||||
}
|
||||
|
||||
@ -681,7 +667,7 @@ struct CollectiveMma<
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_replace_global_tensor_properties(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
Params const& mainloop_params,
|
||||
int32_t next_group,
|
||||
ProblemShape_MNKL problem_shape_mnkl) {
|
||||
@ -716,10 +702,10 @@ struct CollectiveMma<
|
||||
stride = (stride * sizeof_bits_v<InternalElementB>) / 8;
|
||||
}
|
||||
|
||||
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
|
||||
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A,
|
||||
prob_shape_A,
|
||||
prob_stride_A);
|
||||
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_B,
|
||||
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B,
|
||||
prob_shape_B,
|
||||
prob_stride_B);
|
||||
}
|
||||
@ -728,21 +714,19 @@ struct CollectiveMma<
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_perform_update(
|
||||
TensorMapStorage& shared_tensormap,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
Params const& mainloop_params,
|
||||
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
|
||||
ProblemShape_MNKL problem_shape_mnkl,
|
||||
int32_t next_batch) {
|
||||
if (cute::elect_one_sync()) {
|
||||
// Bringing tensormaps to smem
|
||||
tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps);
|
||||
|
||||
// Replacing global_address for the next batch
|
||||
tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch);
|
||||
tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);
|
||||
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
// Replacing global dims and strides for the next batch
|
||||
tensormaps_replace_global_tensor_properties(shared_tensormap,
|
||||
tensormaps_replace_global_tensor_properties(shared_tensormaps,
|
||||
mainloop_params, next_batch, problem_shape_mnkl);
|
||||
}
|
||||
}
|
||||
@ -752,11 +736,11 @@ struct CollectiveMma<
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
tensormaps_cp_fence_release (
|
||||
TensorMapStorage& shared_tensormap,
|
||||
TensorMapStorage& shared_tensormaps,
|
||||
cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
|
||||
// Entire warp must do this (i.e. it's aligned)
|
||||
tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A);
|
||||
tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B);
|
||||
tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A);
|
||||
tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B);
|
||||
}
|
||||
|
||||
// The entire warp must call this function collectively (that is, the instructions are aligned)
|
||||
|
||||
@ -98,6 +98,7 @@ struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpong { };

//////////////////////////////////////////////////////////////////////////////

@ -111,6 +112,7 @@ struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum : KernelPtrArrayTmaWarpSpecializedPingpong { };

// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -286,8 +288,9 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
  using ArchTag = arch::Sm90;
  using Schedule = KernelSchedule;
  static_assert(
    cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule>,
    "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
    cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
    cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>,
    "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies");
};

//////////////////////////////////////////////////////////////////////////////

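A standalone sketch of the relaxed policy check: because the FP8 fast-accum tags derive from the base ptr-array tags, is_base_of_v accepts both the plain and FP8 variants of either the cooperative or the pingpong schedule. The tag structs below are local re-declarations for illustration only.

#include <type_traits>

struct KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpong { };
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum
    : KernelPtrArrayTmaWarpSpecializedPingpong { };

template <class KernelSchedule>
constexpr bool is_valid_ptr_array_schedule =
    std::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
    std::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelSchedule>;

static_assert(is_valid_ptr_array_schedule<KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum>,
              "FP8 fast-accum pingpong derives from the pingpong base tag");
static_assert(!is_valid_ptr_array_schedule<int>, "unrelated types are rejected");

int main() { return 0; }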
@ -61,5 +61,6 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp"
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp"
////////////////////////////////////////////////////////////////////////////////

@ -46,6 +46,8 @@
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp"
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
@ -73,6 +75,9 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
static_assert(cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>);
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
@ -119,8 +124,9 @@ public:
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
||||
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t NumMmaThreads = CUTE_STATIC_V(size(TiledMma{}));
|
||||
static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup;
|
||||
static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
@ -215,11 +221,11 @@ public:
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* mainloop_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count);
|
||||
workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
// Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used
|
||||
@ -275,9 +281,6 @@ public:
|
||||
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
@ -286,6 +289,9 @@ public:
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
@ -363,6 +369,12 @@ public:
|
||||
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
|
||||
static_assert(size<0>(TileShape{}) >= 128,
|
||||
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
|
||||
static_assert(NumMmaWarpGroups == 2, "Cooperative kernels currently only support NumMmaWarpGroups == 2");
|
||||
|
||||
if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v<typename CollectiveEpilogue::DispatchPolicy>) {
|
||||
static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups,
|
||||
"Tiled MmA does not match expected warp groups performing the epilogue");
|
||||
}
|
||||
|
||||
static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
@ -391,7 +403,8 @@ public:
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
int mma_thread_idx = thread_idx % size(TiledMma{});
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto warp_group_idx = canonical_warp_group_idx();
|
||||
auto warp_group_role = WarpGroupRole(warp_group_idx);
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
@ -466,7 +479,9 @@ public:
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
|
||||
const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
|
||||
@ -484,7 +499,7 @@ public:
|
||||
}
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
@ -510,7 +525,7 @@ public:
|
||||
int32_t const sm_count = params.hw_info.sm_count;
|
||||
|
||||
// Fetch a copy of tensormaps for the CTA
|
||||
auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, sm_count, sm_idx);
|
||||
auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx);
|
||||
|
||||
// Update tensormap for the initial batch for the CTA
|
||||
if (work_tile_info.is_valid()) {
|
||||
@ -578,7 +593,7 @@ public:
|
||||
if (work_tile_info.is_valid() && did_batch_change) {
|
||||
curr_batch = next_batch;
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), Int<1>{});
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
|
||||
}
|
||||
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
|
||||
// Since this state is waiting for loads to finish, it must start in the inverted phase.
|
||||
@ -610,7 +625,7 @@ public:
|
||||
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
|
||||
int32_t const sm_count = params.hw_info.sm_count;
|
||||
|
||||
auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, sm_count, sm_idx));
|
||||
auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx));
|
||||
|
||||
bool did_batch_change = true;
|
||||
constexpr bool IsEpiLoad = true;
|
||||
@ -620,23 +635,26 @@ public:
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_load_tensormap,
|
||||
work_tile_info.L_idx
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
0
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate);
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
|
||||
}
|
||||
|
||||
load_order_barrier.wait();
|
||||
while (work_tile_info.is_valid()) {
|
||||
int32_t curr_batch = work_tile_info.L_idx;
|
||||
|
||||
bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler);
|
||||
// Get next work tile
|
||||
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
|
||||
if (compute_epilogue) {
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
@ -649,6 +667,8 @@ public:
|
||||
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
|
||||
}
|
||||
|
||||
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_producer_state,
|
||||
@ -660,36 +680,34 @@ public:
|
||||
shared_storage.tensors.epilogue,
|
||||
epi_load_tensormap,
|
||||
work_tile_info.reduction_subtile_idx(),
|
||||
true // return state prior to last advance
|
||||
wait
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
work_tile_info = next_work_tile_info;
|
||||
did_batch_change = curr_batch != work_tile_info.L_idx;
|
||||
|
||||
if (work_tile_info.is_valid() && did_batch_change) {
|
||||
// Wait for TMA load to finish before updating
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
|
||||
{epi_load_pipe_producer_state.index(), !epi_load_pipe_producer_state.phase(), epi_load_pipe_producer_state.count()};
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
|
||||
epi_load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
|
||||
// tensormap update
|
||||
{
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_load_tensormap,
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
0
|
||||
);
|
||||
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_load_tensormap,
|
||||
work_tile_info.L_idx
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate);
|
||||
}
|
||||
|
||||
if(compute_epilogue) {
|
||||
epi_load_pipe_producer_state.advance(1);
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
|
||||
}
|
||||
}
|
||||
|
||||
} // Scheduler work fetch loop
|
||||
@ -702,32 +720,43 @@ public:
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
|
||||
|
||||
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
|
||||
int32_t const sm_count = params.hw_info.sm_count;
|
||||
// Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it
|
||||
bool do_store_tail = false;
|
||||
// Get a copy of tensormaps
|
||||
auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, sm_count, sm_idx));
|
||||
auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx));
|
||||
|
||||
bool did_batch_change = true;
|
||||
constexpr bool IsEpiLoad = false;
|
||||
|
||||
if (work_tile_info.is_valid()) {
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_store_tensormap,
|
||||
work_tile_info.L_idx
|
||||
);
|
||||
if (warp_idx_in_warp_group == 0) {
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate);
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_store_tensormap,
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
consumer_warp_group_idx
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
|
||||
epi_store_tensormap,
|
||||
lane_predicate,
|
||||
consumer_warp_group_idx);
|
||||
}
|
||||
}
|
||||
|
||||
while (work_tile_info.is_valid()) {
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
|
||||
int32_t curr_batch = work_tile_info.L_idx;
|
||||
@ -743,6 +772,10 @@ public:
|
||||
//
|
||||
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
|
||||
auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
static_assert(cute::is_any_of_v<TileScheduler,
|
||||
detail::PersistentTileSchedulerSm90Group<ProblemShape>,
|
||||
detail::PersistentTileSchedulerSm90>);
|
||||
if(TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline,
|
||||
@ -764,18 +797,16 @@ public:
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
|
||||
}
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
|
||||
|
||||
// Perform reduction across splits, if needed
|
||||
TileScheduler::fixup(
|
||||
params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
|
||||
if (did_batch_change) {
|
||||
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_store_tensormap);
|
||||
}
|
||||
|
||||
if (did_batch_change) {
|
||||
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_store_tensormap);
|
||||
}
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
||||
@ -804,20 +835,31 @@ public:
|
||||
|
||||
did_batch_change = curr_batch != work_tile_info.L_idx;
|
||||
if (work_tile_info.is_valid() && did_batch_change) {
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_store_tensormap,
|
||||
work_tile_info.L_idx
|
||||
);
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
if (warp_idx_in_warp_group == 0) {
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_store_tensormap,
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
consumer_warp_group_idx
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate);
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
|
||||
epi_store_tensormap,
|
||||
lane_predicate,
|
||||
consumer_warp_group_idx);
|
||||
}
|
||||
}
|
||||
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Cooperative only needs TMA to complete at the very end of the kernel
|
||||
if (do_store_tail) {
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline,
|
||||
@ -829,7 +871,6 @@ public:
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,949 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/workspace.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal_decl.h"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
class ProblemShape_,
|
||||
class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_,
|
||||
class TileScheduler_
|
||||
>
|
||||
class GemmUniversal<
|
||||
ProblemShape_,
|
||||
CollectiveMainloop_,
|
||||
CollectiveEpilogue_,
|
||||
TileScheduler_,
|
||||
cute::enable_if_t<cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>>
|
||||
>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
static_assert(cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>);
|
||||
|
||||
static constexpr bool IsGdcEnabled = false;
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using InternalStrideA = typename CollectiveMainloop::InternalStrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using InternalStrideB = typename CollectiveMainloop::InternalStrideB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using Schedule = typename DispatchPolicy::Schedule;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using InternalStrideC = typename CollectiveEpilogue::InternalStrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using InternalStrideD = typename CollectiveEpilogue::InternalStrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
static_assert(cute::is_void_v<TileScheduler_>,
|
||||
"Ptr-Array Pingpong and Grouped Gemm Pingpong kernel only supports the default scheduler.");
|
||||
|
||||
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;
|
||||
|
||||
using TileScheduler = cute::conditional_t<IsGroupedGemmKernel,
|
||||
typename detail::TileSchedulerSelector<
|
||||
GroupScheduler, ArchTag,
|
||||
TileShape, ClusterShape,
|
||||
ProblemShape>::Scheduler,
|
||||
typename detail::TileSchedulerSelector<
|
||||
void, ArchTag, TileShape, ClusterShape>::Scheduler>;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = 2;
|
||||
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
|
||||
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
|
||||
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||
|
||||
using MathWarpGroupOrderBarrierSharedStorage = cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage<
|
||||
MathWarpGroupOrderBarrier::SequenceDepth,
|
||||
MathWarpGroupOrderBarrier::SequenceLength>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16> {
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
using MathWarpGroupOrderBarrierStorage = MathWarpGroupOrderBarrierSharedStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
|
||||
} pipelines;
|
||||
|
||||
struct TensorMapStorage : cute::aligned_struct<128> {
|
||||
using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage;
|
||||
using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage;
|
||||
|
||||
alignas(128) MainloopTensorMapStorage mainloop;
|
||||
alignas(128) EpilogueTensorMapStorage epilogue;
|
||||
} tensormaps;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
void* workspace{nullptr};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static
|
||||
Params
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
ProblemShape problem_shapes = args.problem_shape;
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
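// Workspace is carved out in the same order as in get_workspace_size() and initialize_workspace(): tile scheduler, then epilogue, then mainloop.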
|
||||
|
||||
void* scheduler_workspace = workspace_ptr;
|
||||
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
|
||||
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* mainloop_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
// Precompute the number of epilogue subtiles and pass it to the tile scheduler, where it is used by the
// separate-reduction scheme in the stream-K case. NumEpilogueSubTiles defaults to 1, which means subtiles
// are not used and separate reduction is therefore not enabled.
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
TileSchedulerParams scheduler;
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
scheduler = TileScheduler::to_underlying_arguments(
|
||||
problem_shapes, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
|
||||
}
|
||||
else {
|
||||
scheduler = TileScheduler::to_underlying_arguments(
|
||||
problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
|
||||
}
|
||||
|
||||
return {
|
||||
args.mode,
|
||||
problem_shapes,
|
||||
CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace),
|
||||
hw_info,
|
||||
scheduler,
|
||||
workspace
|
||||
};
|
||||
}
|
||||
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
bool implementable = true;
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
// Group GEMM currently only supports rank-3 problem shapes
|
||||
implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3);
|
||||
} else {
|
||||
implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
|
||||
}
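// In short: grouped GEMM requires kGrouped mode with rank-3 (M,N,K) per-group shapes, while ptr-array GEMM requires kArray mode with a single rank-4 (M,N,K,L) shape.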
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_size = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
workspace_size += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
|
||||
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
status = TileScheduler::template initialize_workspace<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
|
||||
args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, cuda_adapter);
|
||||
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
|
||||
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
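// Reconstruct the scheduler arguments from the canonicalized params so the grid computed here matches the scheduler's runtime view.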
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM;
|
||||
dim3 grid_shape;
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
grid_shape = TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
|
||||
}
|
||||
else {
|
||||
grid_shape = TileScheduler::get_grid_shape(params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args);
|
||||
}
|
||||
return grid_shape;
|
||||
}
|
||||
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
operator()(Params const& params, char* smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(size(TiledMma{}) == 128, "Pingpong kernel must have TiledMMA operating using 128 threads.");
|
||||
static_assert(NumMmaWarpGroups == 2, "Pingpong kernels currently only support NumMmaWarpGroups == 2");
|
||||
|
||||
if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v<typename CollectiveEpilogue::DispatchPolicy>) {
|
||||
static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups,
|
||||
"Tiled MmA does not match expected warp groups performing the epilogue");
|
||||
}
|
||||
|
||||
static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
enum class WarpGroupRole {
  Producer = 0,
  Consumer0 = 1,
  Consumer1 = 2
};
enum class ProducerWarpRole {
  Mainloop = 0,
  Warp1 = 1,
  Epilogue = 2,
  Warp3 = 3
};

// Kernel level shared memory storage
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
int mma_thread_idx = thread_idx % size(TiledMma{});
|
||||
auto warp_group_idx = canonical_warp_group_idx();
|
||||
auto warp_group_role = WarpGroupRole(warp_group_idx);
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
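// Within the producer warp group, the Mainloop warp issues A/B TMA loads and the Epilogue warp issues C loads; Warp1/Warp3 have no load duties.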
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Note: Tma Descriptor Prefetch (from either const or param) is not applicable here
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
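// A single producer warp issues epilogue loads, while a full warp group consumes them.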
|
||||
if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
|
||||
epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
|
||||
}
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
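// Orders the two producer warps: the mainloop load warp arrives after issuing its first tile, and the epilogue load warp waits on it before issuing C loads.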
|
||||
|
||||
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_math_wg_order_barrier.group_id = warp_group_idx - static_cast<int>(WarpGroupRole::Consumer0);
|
||||
params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||
MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = [] () {
|
||||
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer thread blocks in the cluster
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
return [] () { cute::cluster_wait(); };
|
||||
}
|
||||
else {
|
||||
__syncthreads();
|
||||
return [] () {}; // do nothing
|
||||
}
|
||||
} ();
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
TiledMma tiled_mma;
|
||||
const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
|
||||
const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
|
||||
if (not work_tile_info.is_valid()) {
|
||||
// When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups
|
||||
return;
|
||||
}
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
const auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
|
||||
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
work_tile_info = next_work_tile_info;
|
||||
if (!work_tile_info.is_valid()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count);
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
|
||||
int32_t const mock_l_coord = 0;
|
||||
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
|
||||
int32_t const sm_count = params.hw_info.sm_count;
|
||||
|
||||
// Fetch a copy of tensormaps for the CTA
|
||||
auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx);
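// sm_count/sm_idx select this SM's own tensormap copy in workspace, so descriptor updates from different CTAs do not race.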
|
||||
|
||||
// Update tensormap for the initial batch for the CTA
|
||||
if (work_tile_info.is_valid()) {
|
||||
collective_mainloop.tensormaps_perform_update(
|
||||
shared_storage.tensormaps.mainloop,
|
||||
params.mainloop,
|
||||
input_tensormaps,
|
||||
problem_shape_MNKL,
|
||||
curr_batch
|
||||
);
|
||||
// Ensure warp is converged before issuing tensormap fence release
|
||||
__syncwarp();
|
||||
// Entire warp must do this (i.e. it's aligned)
|
||||
collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);
|
||||
}
|
||||
|
||||
bool do_load_order_arrive = true;
|
||||
bool did_batch_change = true;
|
||||
while (work_tile_info.is_valid()) {
|
||||
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
|
||||
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
work_tile_info = next_work_tile_info;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord);
|
||||
|
||||
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the work.
|
||||
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
|
||||
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
|
||||
|
||||
if (did_batch_change) {
|
||||
collective_mainloop.tensormaps_fence_acquire(input_tensormaps);
|
||||
}
|
||||
|
||||
collective_mainloop.load(
|
||||
params.mainloop,
|
||||
mainloop_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
load_inputs,
|
||||
input_tensormaps,
|
||||
blk_coord,
|
||||
k_tile_iter, work_k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
// Update starting pipeline state for the next tile
|
||||
// Wait for the last TMA stage to complete loading, before issuing tensormap updates
|
||||
mainloop_pipe_producer_state.advance(work_k_tile_count - 1);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive) {
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
work_tile_info = next_work_tile_info;
|
||||
auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx
|
||||
did_batch_change = next_batch != curr_batch;
|
||||
if (work_tile_info.is_valid() && did_batch_change) {
|
||||
curr_batch = next_batch;
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
|
||||
}
|
||||
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
|
||||
// Since this state is waiting for loads to finish, it must start in the inverted phase.
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state =
|
||||
{mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()};
|
||||
mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state);
|
||||
collective_mainloop.tensormaps_perform_update(
|
||||
shared_storage.tensormaps.mainloop,
|
||||
params.mainloop,
|
||||
input_tensormaps,
|
||||
problem_shape_MNKL,
|
||||
curr_batch
|
||||
);
|
||||
// Ensure warp is converged before issuing tensor replace
|
||||
__syncwarp();
|
||||
// Entire warp must do this (i.e. it's aligned)
|
||||
collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);
|
||||
}
|
||||
// Advance the producer state for the last remaining stage that was being waited for above
|
||||
mainloop_pipe_producer_state.advance(1);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
|
||||
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
|
||||
int32_t const sm_count = params.hw_info.sm_count;
|
||||
|
||||
auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx));
|
||||
|
||||
bool did_batch_change = true;
|
||||
constexpr bool IsEpiLoad = true;
|
||||
|
||||
if (work_tile_info.is_valid()) {
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_load_tensormap,
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
0
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
|
||||
}
|
||||
|
||||
load_order_barrier.wait();
|
||||
|
||||
while (work_tile_info.is_valid()) {
|
||||
int32_t curr_batch = work_tile_info.L_idx;
|
||||
|
||||
// Get next work tile
|
||||
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
if (did_batch_change) {
|
||||
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
|
||||
}
|
||||
|
||||
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
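// A batch change before the next tile means the tensormap will be updated, so ask the load to drain first.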
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_producer_state,
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord,
|
||||
tiled_mma,
|
||||
lane_idx,
|
||||
shared_storage.tensors.epilogue,
|
||||
epi_load_tensormap,
|
||||
work_tile_info.reduction_subtile_idx(),
|
||||
wait
|
||||
);
|
||||
}
|
||||
|
||||
work_tile_info = next_work_tile_info;
|
||||
did_batch_change = curr_batch != work_tile_info.L_idx;
|
||||
|
||||
if (work_tile_info.is_valid() && did_batch_change) {
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
|
||||
// tensormap update
|
||||
{
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_load_tensormap,
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
0
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
|
||||
}
|
||||
}
|
||||
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
|
||||
|
||||
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
|
||||
int32_t const sm_count = params.hw_info.sm_count;
|
||||
// Whether we may need to issue tail arrives for TMA stores, in case the epilogue load is waiting on them
|
||||
bool do_store_tail = false;
|
||||
// Get a copy of tensormaps
|
||||
auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx));
|
||||
|
||||
bool did_batch_change = true;
|
||||
constexpr bool IsEpiLoad = false;
|
||||
|
||||
if (work_tile_info.is_valid()) {
|
||||
|
||||
if (warp_idx_in_warp_group == 0) {
|
||||
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_store_tensormap,
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
consumer_warp_group_idx
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
|
||||
epi_store_tensormap,
|
||||
lane_predicate,
|
||||
consumer_warp_group_idx);
|
||||
}
|
||||
}
|
||||
|
||||
while (work_tile_info.is_valid()) {
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
|
||||
int32_t curr_batch = work_tile_info.L_idx;
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
//
|
||||
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
|
||||
auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
static_assert(cute::is_any_of_v<TileScheduler,
|
||||
detail::PersistentTileSchedulerSm90Group<ProblemShape>,
|
||||
detail::PersistentTileSchedulerSm90>);
|
||||
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
|
||||
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
accumulators,
|
||||
work_k_tile_count,
|
||||
mma_thread_idx,
|
||||
shared_storage.tensors.mainloop,
|
||||
params.mainloop
|
||||
);
|
||||
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
collective_mainloop.mma_tail(
|
||||
mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
work_k_tile_count
|
||||
);
|
||||
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
|
||||
}
|
||||
|
||||
// Perform reduction across splits, if needed
|
||||
TileScheduler::fixup(
|
||||
params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
|
||||
if (did_batch_change) {
|
||||
collective_epilogue.tensormaps_fence_acquire<IsEpiLoad>(epi_store_tensormap);
|
||||
}
|
||||
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
||||
collective_epilogue.store(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline,
|
||||
epi_store_pipe_producer_state,
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord,
|
||||
accumulators,
|
||||
tiled_mma,
|
||||
mma_thread_idx,
|
||||
shared_storage.tensors.epilogue,
|
||||
epi_store_tensormap,
|
||||
work_tile_info.reduction_subtile_idx()
|
||||
);
|
||||
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
|
||||
do_store_tail = true;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
work_tile_info = next_work_tile_info;
|
||||
|
||||
// Skip a tile for pingpong
|
||||
if (work_tile_info.is_valid()) {
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
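// Account for the K tiles of the skipped tile, which the other math warp group consumes.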
|
||||
|
||||
// Go to next tile
|
||||
auto next_next_work_tile_info = scheduler.fetch_next_work(work_tile_info);
|
||||
|
||||
work_tile_info = next_next_work_tile_info;
|
||||
}
|
||||
|
||||
did_batch_change = curr_batch != work_tile_info.L_idx;
|
||||
if (work_tile_info.is_valid() && did_batch_change) {
|
||||
if constexpr (IsGroupedGemmKernel) {
|
||||
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
|
||||
}
|
||||
if (warp_idx_in_warp_group == 0) {
|
||||
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
|
||||
shared_storage.tensormaps.epilogue,
|
||||
params.epilogue,
|
||||
epi_store_tensormap,
|
||||
problem_shape_MNKL,
|
||||
work_tile_info.L_idx,
|
||||
consumer_warp_group_idx
|
||||
);
|
||||
|
||||
// Converge before issuing tensormap fence release since fence is aligned
|
||||
__syncwarp();
|
||||
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
|
||||
epi_store_tensormap,
|
||||
lane_predicate,
|
||||
consumer_warp_group_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels
|
||||
// we need to wait for all TMA stores to complete before issuing consumer order barrier arrives
|
||||
// to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer.
|
||||
auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] =
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline,
|
||||
epi_store_pipe_producer_state
|
||||
);
|
||||
|
||||
// Update starting load/store pipeline states for the next tile
// The states have already been advanced by one tile inside the collective calls; advance once more to skip the pingpong tile
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
} // Scheduler work fetch loop
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@ -338,12 +338,14 @@ cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_add_executable(
  cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array
  sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
  sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu
)

# Group Gemm test
cutlass_test_unit_add_executable(
  cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm
  sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu
  sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu
)

# Fused epilogue tests
|
||||
|
||||
@ -1005,7 +1005,7 @@ struct HostCollectiveEpilogue {
|
||||
stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t<LayoutTagAux>{}, cute::make_shape(M, N, 1));
|
||||
}
|
||||
|
||||
static_assert(!IsGroupGemm or (IsGroupGemm and IsAuxOutEnabled));
|
||||
static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled));
|
||||
|
||||
if constexpr (IsAuxOutEnabled) {
|
||||
for (int32_t i = 0; i < L; ++i) {
|
||||
@ -1323,8 +1323,16 @@ struct HostCollectiveEpilogue {
|
||||
cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch]));
|
||||
auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, cute::_1{})));
|
||||
auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensors_Aux[batch].host_data() : references_Aux[batch].host_data()),
|
||||
cute::make_layout(cute::make_shape(M, N, 1), stride_Aux));
|
||||
auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux);
auto Aux = [&]() {
  auto ptr = recast_ptr<ElementAux>(nullptr);
  if (IsAuxInEnabled) {
    ptr = detail::make_iterator(tensors_Aux[batch].host_data());
  } else if (IsAuxOutEnabled) {
    ptr = detail::make_iterator(references_Aux[batch].host_data());
  }
  return cute::make_tensor(ptr, Aux_layout);
}();
auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()),
|
||||
cute::make_layout(cute::make_shape(M, cute::_1{})));
|
||||
auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()),
|
||||
|
||||
@ -78,7 +78,69 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // M
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
|
||||
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
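// Group GEMM: each group has its own (M,N,K) problem size, so per-group strides are supplied via the pointer layout tags (LayoutA *, LayoutB *, LayoutC *) above.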
|
||||
|
||||
using namespace test::gemm::device;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
bool result = TestAll<Gemm>(1.0, 1.0);
|
||||
EXPECT_TRUE(result);
|
||||
result = TestAll<Gemm>(1.0, 0.0);
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm, 128x128x64_2x2x1_direct_store) {
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
@ -115,6 +177,8 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
bool result = TestAll<Gemm>(1.0, 1.0);
|
||||
EXPECT_TRUE(result);
|
||||
result = TestAll<Gemm>(1.0, 0.0);
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
@ -0,0 +1,184 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide Ptr-Array Ping-pong scheduler GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x_ptr_array.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1) {
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using namespace test::gemm::device;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
bool result = TestAll<Gemm>(1.0, 1.0);
|
||||
EXPECT_TRUE(result);
|
||||
result = TestAll<Gemm>(1.0, 0.0);
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1_direct_store) {
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using namespace test::gemm::device;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
bool result = TestAll<Gemm>(1.0, 1.0);
|
||||
EXPECT_TRUE(result);
|
||||
result = TestAll<Gemm>(1.0, 0.0);
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
@ -115,9 +115,11 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
bool result = TestAll<Gemm>(1.0, 1.0);
|
||||
EXPECT_TRUE(result);
|
||||
result = TestAll<Gemm>(1.0, 0.0);
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_NoSmemEpi) {
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_direct_store) {
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
@ -173,6 +175,7 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
|
||||
using namespace test::gemm::device;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
EXPECT_TRUE(TestAll<Gemm>(1.0, 1.0));
|
||||
EXPECT_TRUE(TestAll<Gemm>(1.0, 0.0));
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,182 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for device-wide Ptr-Array GEMM interface
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x_ptr_array.hpp"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1) {

  // A matrix configuration
  using ElementA = cutlass::half_t;                                          // Element type for A matrix operand
  using LayoutA  = cutlass::layout::RowMajor;                                // Layout type for A matrix operand
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;    // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = cutlass::half_t;                                          // Element type for B matrix operand
  using LayoutB  = cutlass::layout::ColumnMajor;                             // Layout type for B matrix operand
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = cutlass::half_t;                                          // Element type for C and D matrix operands
  using LayoutC  = cutlass::layout::ColumnMajor;                             // Layout type for C and D matrix operands
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;    // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

  // Core kernel configurations
  using ElementAccumulator = float;                                          // Element type for internal accumulation
  using ArchTag            = cutlass::arch::Sm90;                            // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass      = cutlass::arch::OpClassTensorOp;                 // Operator class tag
  using TileShape          = Shape<_128,_128,_64>;                           // Threadblock-level tile size
  using ClusterShape       = Shape<_2,_2,_1>;                                // Shape of the threadblocks in a cluster
  using StageCountType     = cutlass::gemm::collective::StageCountAuto;      // Stage count maximized based on the tile size
  using KernelSchedule     = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;  // Kernel to launch
  using EpilogueSchedule   = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;    // Epilogue to launch

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutC, AlignmentC,
      ElementC, LayoutC, AlignmentC,
      EpilogueSchedule
    >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutA, AlignmentA,
      ElementB, LayoutB, AlignmentB,
      ElementAccumulator,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,  // Reserve the epilogue's shared memory, then maximize mainloop stages in what remains
      KernelSchedule
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
      CollectiveMainloop,
      CollectiveEpilogue
  >;

  using namespace test::gemm::device;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  bool result = TestAll<Gemm>(1.0, 1.0);
  EXPECT_TRUE(result);
  result = TestAll<Gemm>(1.0, 0.0);
  EXPECT_TRUE(result);
}

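// Same pingpong ptr-array mainloop as above, but with a direct-store epilogue:
// PtrArrayNoSmemWarpSpecialized writes D straight to global memory instead of
// staging the output tile through shared memory and TMA.
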
TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1_direct_store) {

  // A matrix configuration
  using ElementA = cutlass::half_t;                                          // Element type for A matrix operand
  using LayoutA  = cutlass::layout::RowMajor;                                // Layout type for A matrix operand
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;    // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = cutlass::half_t;                                          // Element type for B matrix operand
  using LayoutB  = cutlass::layout::ColumnMajor;                             // Layout type for B matrix operand
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = cutlass::half_t;                                          // Element type for C and D matrix operands
  using LayoutC  = cutlass::layout::ColumnMajor;                             // Layout type for C and D matrix operands
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;    // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

  // Core kernel configurations
  using ElementAccumulator = float;                                          // Element type for internal accumulation
  using ArchTag            = cutlass::arch::Sm90;                            // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass      = cutlass::arch::OpClassTensorOp;                 // Operator class tag
  using TileShape          = Shape<_128,_128,_64>;                           // Threadblock-level tile size
  using ClusterShape       = Shape<_2,_2,_1>;                                // Shape of the threadblocks in a cluster
  using StageCountType     = cutlass::gemm::collective::StageCountAuto;      // Stage count maximized based on the tile size
  using KernelSchedule     = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;  // Kernel to launch
  using EpilogueSchedule   = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;         // Epilogue to launch

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutC, AlignmentC,
      ElementC, LayoutC, AlignmentC,
      EpilogueSchedule
    >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutA, AlignmentA,
      ElementB, LayoutB, AlignmentB,
      ElementAccumulator,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
      CollectiveMainloop,
      CollectiveEpilogue
  >;

  using namespace test::gemm::device;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(TestAll<Gemm>(1.0, 1.0));
  EXPECT_TRUE(TestAll<Gemm>(1.0, 0.0));
}

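// TestAll hides the runtime argument plumbing, so a rough sketch of how a ptr-array
// kernel of this shape is launched is given below. It follows the public CUTLASS 3.x
// ptr-array examples rather than this testbed, and the names M, N, K, batch_count,
// ptr_A_array, stride_A, hw_info, alpha, beta are illustrative placeholders.
//
//   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
//   typename Gemm::Arguments args{
//     cutlass::gemm::GemmUniversalMode::kArray,                      // ptr-array (batched) mode
//     {{M, N, K, batch_count}},                                      // ArrayProblemShape: one MxNxK problem per batch
//     {ptr_A_array, stride_A, ptr_B_array, stride_B},                // device arrays of per-batch A/B pointers
//     {{alpha, beta}, ptr_C_array, stride_C, ptr_D_array, stride_D}, // linear-combination epilogue + per-batch C/D pointers
//     hw_info};
//
//   Gemm gemm;
//   size_t workspace_size = Gemm::get_workspace_size(args);
//   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
//   EXPECT_EQ(gemm.can_implement(args), cutlass::Status::kSuccess);
//   EXPECT_EQ(gemm.initialize(args, workspace.get()), cutlass::Status::kSuccess);
//   EXPECT_EQ(gemm.run(), cutlass::Status::kSuccess);
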
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)