diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index 7a191ce2..5181678c 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -95,40 +95,66 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // M using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size -using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch -using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch -using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementAccumulator, - ElementC, LayoutC, AlignmentC, - ElementC, LayoutC, AlignmentC, - EpilogueSchedule - >::CollectiveOp; +// Different configs for pingpong/cooperative +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_256,_128,_64>; + using ClusterShape = Shape<_1,_2,_1>; +}; -using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule - >::CollectiveOp; +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_64,_128,_64>; + using ClusterShape = Shape<_1,_1,_1>; +}; -using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cutlass::gemm::ArrayProblemShape>, - CollectiveMainloop, - CollectiveEpilogue ->; +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + 
ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel = GemmGivenSchedule::GemmKernel; +using Gemm = GemmGivenSchedule::Gemm; + +using GemmKernelPingpong = GemmGivenSchedule::GemmKernel; +using GemmPingpong = GemmGivenSchedule::Gemm; -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; // Reference device GEMM implementation type using DeviceGemmReference = cutlass::reference::device::Gemm< @@ -261,14 +287,14 @@ bool initialize_block( int bits_input = cutlass::sizeof_bits::value; if (bits_input == 1) { - scope_max = 2; - scope_min = 0; + scope_max = static_cast(2); + scope_min = static_cast(0); } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; + scope_max = static_cast(2); + scope_min = static_cast(-2); } else { - scope_max = 8; - scope_min = -8; + scope_max = static_cast(8); + scope_min = static_cast(-8); } cutlass::reference::device::BlockFillRandomUniform( @@ -351,7 +377,8 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) +template +typename GemmT::Arguments args_from_options(const Options &options) { cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with multiple GPUs and wish @@ -359,7 +386,7 @@ typename Gemm::Arguments args_from_options(const Options &options) hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Gemm::Arguments arguments{ + typename GemmT::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kArray, {{options.m, options.n, options.k, options.l}}, {ptr_A.get(), stride_A, ptr_B.get(), stride_B}, @@ -405,20 +432,20 @@ bool verify(const Options &options) { } /// Execute a given example GEMM computation -template +template int run(Options &options) { allocate(options); initialize(options); // Instantiate CUTLASS kernel depending on templates - Gemm gemm; + GemmT gemm; // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm - auto arguments = args_from_options(options); + auto arguments = args_from_options(options); // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); + size_t workspace_size = GemmT::get_workspace_size(arguments); // Allocate workspace memory cutlass::device_memory::allocation workspace(workspace_size); @@ -510,7 +537,10 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + std::cout << "\n*** Cooperative schedule ***" << std::endl; run(options); + std::cout << "\n*** Pingpong schedule ***" << std::endl; + run(options); #endif return 0; diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index f9467956..a26d904d 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ 
b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // A using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size -using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch -using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch -using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +// Different configs for pingpong/cooperative +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_256,_128,_128>; + using ClusterShape = Shape<_2,_2,_1>; +}; + +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_2,_1,_1>; +}; + +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutC *, AlignmentC, ElementC, LayoutC *, AlignmentC, - EpilogueSchedule + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder KernelSchedule >::CollectiveOp; -using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloop, - CollectiveEpilogue ->; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel = GemmGivenSchedule::GemmKernel; +using Gemm = GemmGivenSchedule::Gemm; + +using GemmKernelPingpong = GemmGivenSchedule::GemmKernel; +using GemmPingpong = GemmGivenSchedule::Gemm; // Reference device GEMM implementation type using DeviceGemmReference = cutlass::reference::device::Gemm< @@ -271,10 +297,10 @@ struct Options { int n = cmd_line_n; int k = cmd_line_k; if (m < 1) { - m = ((rand() % 512) + 1); + m = alignment * ((rand() % 64) + 1); } if (n < 1) { - n = ((rand() % 512) + 1); + n = alignment * 
((rand() % 64) + 1); } if (k < 1) { k = alignment * ((rand() % 64) + 1); @@ -521,7 +547,8 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true) +template +typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true) { cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with multiple GPUs and wish @@ -529,33 +556,49 @@ typename Gemm::Arguments args_from_options(const Options &options, bool host_pro hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Gemm::EpilogueOutputOp::Params params; + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. - params = typename Gemm::EpilogueOutputOp::Params( - ElementAccumulator(options.alpha), ElementAccumulator(options.beta)); + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; } else { // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. - params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get()); + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; } - typename Gemm::Arguments arguments; if (host_problem_shapes_available) { - arguments = typename Gemm::Arguments { + arguments = typename GemmT::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; } else { - arguments = typename Gemm::Arguments { + arguments = typename GemmT::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; } @@ -605,20 +648,20 @@ bool verify(const Options &options) { } /// Execute a given example GEMM computation -template +template int run(Options &options, bool host_problem_shapes_available = true) { allocate(options); initialize(options); // Instantiate CUTLASS kernel depending on templates - Gemm gemm; + GemmT gemm; // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm - auto arguments = args_from_options(options, 
host_problem_shapes_available); + auto arguments = args_from_options(options, host_problem_shapes_available); // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); + size_t workspace_size = GemmT::get_workspace_size(arguments); // Allocate workspace memory cutlass::device_memory::allocation workspace(workspace_size); @@ -713,8 +756,14 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + std::cout << "\n*** Cooperative schedule ***" << std::endl; run(options); + std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl; run(options, false /*host_problem_shapes_available*/); + std::cout << "\n*** Pingpong schedule ***" << std::endl; + run(options); + std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl; + run(options, false /*host_problem_shapes_available*/); #endif return 0; diff --git a/examples/57_hopper_grouped_gemm/CMakeLists.txt b/examples/57_hopper_grouped_gemm/CMakeLists.txt index 2c3ff3a4..1dadbfa8 100644 --- a/examples/57_hopper_grouped_gemm/CMakeLists.txt +++ b/examples/57_hopper_grouped_gemm/CMakeLists.txt @@ -32,10 +32,10 @@ set(TEST_RANDOM --iterations=0) # Random problem sizes set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes -set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes -set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index f12cdb59..f0eb5511 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -274,4 +274,18 @@ struct conditional_template { using type = False; }; +// +// is_any_of +// + +/// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us +template +struct is_any_of { + constexpr static bool value = (... 
|| CUTE_STL_NAMESPACE::is_same_v); +}; + +/// Is true if and only if T is same as (is_same_v) at least one of the types in Us +template +inline constexpr bool is_any_of_v = is_any_of::value; + } // end namespace cute diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 2ca62c97..90a60002 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -71,14 +71,18 @@ sm90_get_tma_dispatch_policy() { // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); // TMA store delay performs worse with residual loads and compilicates tensormap updates for Ptr-Array GEMMs - constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_tma_ptr_array_v; + constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_ptr_array_tma_v; constexpr int StagesD = cute::min(EpiTiles, 2); constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) : cute::min(EpiTiles, 4); - return cute::conditional_t, - Sm90PtrArrayTmaWarpSpecialized, - Sm90TmaWarpSpecialized>{}; + if constexpr (detail::sm90_is_ptr_array_tma_v) { + return Sm90PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm90TmaWarpSpecialized{}; + } } // Returns the smem layout atom to be used for C or D matrix @@ -255,6 +259,9 @@ struct Sm90TmaBuilderImpl { using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; + using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; + using CopyOpS2G = cute::conditional_t, SM90_TMA_STORE_IM2COL, SM90_TMA_STORE @@ -267,17 +274,11 @@ struct Sm90TmaBuilderImpl { // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; - using FusionDispatchPolicy = Sm90TmaWarpSpecialized; - // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. 
fusion::Sm90LinearCombination using FusionCallbacks = typename CallbacksBuilder< - FusionDispatchPolicy, + DispatchPolicy, FusionOpOrCallbacks, TileShape_MNK, EpilogueTile_MN, @@ -294,11 +295,11 @@ struct Sm90TmaBuilderImpl { GmemStrideTypeD, FusionCallbacks, CopyOpG2S, - decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), - decltype(detail::sm90_get_smem_load_op_for_source()), + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), CopyOpS2G, - decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), - decltype(detail::sm90_get_smem_store_op_for_accumulator()), + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()), CopyAtomC >; }; @@ -483,7 +484,7 @@ struct CollectiveBuilder< FusionOperation, cute::enable_if_t || cute::is_same_v || - cute::is_same_v >> { + detail::sm90_is_ptr_array_tma_v>> { private: using ElementD = cute::conditional_t, fusion::get_element_aux_t, ElementD_>; diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index a6e5e2f4..b96b13fe 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -71,6 +71,62 @@ is_im2col() { || cute::is_same_v>; } +template +struct sm90_is_ptr_array_tma : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_v = sm90_is_ptr_array_tma::value; + +template +struct sm90_is_ptr_array_tma_cooperative : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_cooperative : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_cooperative_v = sm90_is_ptr_array_tma_cooperative::value; + +template +struct sm90_is_ptr_array_tma_pingpong : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_pingpong : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_pingpong_v = sm90_is_ptr_array_tma_pingpong::value; + +template +struct sm90_is_ptr_array_tma_dispatch_policy : cute::false_type {}; + +template< + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups +> +struct sm90_is_ptr_array_tma_dispatch_policy< + Sm90PtrArrayTmaWarpSpecialized> + : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy::value; + using cutlass::atomic_maximum; template @@ -79,14 +135,11 @@ static constexpr int elements_per_access_v = cutlass::sizeof_bits::val template static constexpr bool sm90_is_cooperative_v = cute::is_base_of_v || - cute::is_base_of_v; - -template -static constexpr bool sm90_is_tma_ptr_array_v = - cute::is_base_of_v; + sm90_is_ptr_array_tma_cooperative_v; template static constexpr bool sm90_is_warp_specialized_v = + (!sm90_is_ptr_array_tma_cooperative_v && sm90_is_ptr_array_tma_v) || cute::is_base_of_v; template @@ -199,7 +252,11 @@ public: } CUTLASS_DEVICE auto - load_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, [[maybe_unused]] int32_t const sm_idx) const { + load_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] 
int32_t sm_count, + [[maybe_unused]] int32_t sm_idx) { return cute::make_tuple(nullptr); } @@ -243,7 +300,7 @@ public: [[maybe_unused]] TensorStorage& shared_tensors, [[maybe_unused]] TensorMapC const& load_tensormap, [[maybe_unused]] int subtile_idx=-1, - [[maybe_unused]] bool return_prior_state = false) + [[maybe_unused]] bool wait = false) { return load_pipe_producer_state; } @@ -257,8 +314,12 @@ public: } CUTLASS_DEVICE auto - store_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { + store_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx) { return cute::make_tuple(nullptr); } @@ -369,22 +430,25 @@ public: // Dummy methods to perform different parts of TMA/Tensormap modifications - template + template CUTLASS_DEVICE void tensormaps_perform_update( - [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] int32_t next_batch) { } + [[maybe_unused]] ProblemShapeMNKL problem_shape, + [[maybe_unused]] int32_t next_batch, + [[maybe_unused]] int32_t warp_group_idx) { } template CUTLASS_DEVICE void tensormaps_cp_fence_release( - [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] uint32_t lane_predicate) { } + [[maybe_unused]] uint32_t lane_predicate, + [[maybe_unused]] int32_t warp_group_idx) { } template CUTLASS_DEVICE diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 87e62887..87b67867 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -44,9 +44,10 @@ #include "cutlass/detail/collective.hpp" #include "cutlass/detail/layout.hpp" #include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" #include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -62,6 +63,7 @@ template < int FragmentSize_, bool ReuseSmemC_, bool DelayTmaStore_, + int NumEpilogueWarpGroups_, class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) class ElementC_, @@ -78,7 +80,13 @@ template < class CopyAtomC_ > class CollectiveEpilogue< - Sm90PtrArrayTmaWarpSpecialized, + Sm90PtrArrayTmaWarpSpecialized, CtaTileMNK_, EpilogueTile_, ElementC_, @@ -98,7 +106,13 @@ public: // // Type Aliases // - using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; + using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; using CtaTileMNK = CtaTileMNK_; using EpilogueTile = EpilogueTile_; using FusionCallbacks = FusionCallbacks_; @@ -201,6 +215,8 @@ public: (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; constexpr static bool RequiresTransactionBytes = true; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; + // TMA pipeline for storing D using 
StorePipeline = cute::conditional_t, @@ -219,7 +235,7 @@ public: struct TensorMapStorage : cute::aligned_struct<128> { cute::TmaDescriptor smem_tensormap_C; - cute::TmaDescriptor smem_tensormap_D; + cute::array smem_tensormap_D; } tensormaps; using PipelineStorage = typename LoadPipeline::SharedStorage; @@ -229,6 +245,8 @@ public: using TensorMapStorage = typename SharedStorage::TensorMapStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + // Host side epilogue arguments struct Arguments { typename FusionCallbacks::Arguments thread{}; @@ -261,7 +279,9 @@ public: TMA_D tma_store_d; cute::TmaDescriptor* tensormaps; ElementC const** ptr_C; + StrideC dC; ElementD** ptr_D; + StrideD dD; uint32_t tma_transaction_bytes = TmaTransactionBytes; }; @@ -275,36 +295,57 @@ public: ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1); - auto [M, N, K, mock_L] = problem_shape_MNKL; - // Manage batches/groups through pointers to input matricies - mock_L = 1; + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_L = get<3>(init_shape); static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_c = InternalStrideC{}; + stride_d = InternalStrideD{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
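// Illustrative sketch, not part of the patch: the IsGroupedGemmKernel flag above has its
// template arguments stripped in this excerpt. Elsewhere in the patch the grouped path is
// distinguished by pointer-qualified strides (see the cute::remove_pointer_t aliases for
// UnderlyingGmemStrideTypeC/D in the builder). Assuming the stripped comparison is between
// the per-group stride type and the possibly pointer-qualified Stride type, the detection
// reduces to standard type traits; stand-in types below are hypothetical.
#include <tuple>
#include <type_traits>

using PerBatchStride = std::tuple<long, int, long>;  // stand-in for a CuTe stride
using PtrArrayStride = PerBatchStride;               // one stride shared by all batches
using GroupedStride  = PerBatchStride*;              // one stride per group (device array)

template <class Stride>
inline constexpr bool is_grouped_v =
    !std::is_same_v<std::remove_pointer_t<Stride>, Stride>;

static_assert(!is_grouped_v<PtrArrayStride>);  // ptr-array GEMM: single stride, known on host
static_assert(is_grouped_v<GroupedStride>);    // grouped GEMM: strides replaced per group on device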
+ auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); + init_M = get<0>(problem_shape_MNKL); + init_N = get<1>(problem_shape_MNKL); + init_L = get<3>(problem_shape_MNKL); + + stride_c = args.dC; + stride_d = args.dD; + } + uint32_t transaction_bytes = TmaTransactionBytes; typename Params::TMA_C tma_load_c = {}; if constexpr (is_source_supported) { ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); - Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dC, _0{}))); - tma_load_c = make_tma_copy_C_sm90( + Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); + tma_load_c = make_tma_copy( CopyOpG2S{}, tensor_c, take<0,2>(SmemLayoutC{}), - EpilogueTile{}); + EpilogueTile{}, + _1{}); } typename Params::TMA_D tma_store_d; if constexpr (is_destination_supported) { ElementD const* ptr_D_first_batch = reinterpret_cast(args.ptr_D); - Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dD, _0{}))); - tma_store_d = make_tma_copy_C_sm90( + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + tma_store_d = make_tma_copy( CopyOpS2G{}, tensor_d, take<0,2>(SmemLayoutD{}), - EpilogueTile{}); + EpilogueTile{}, + _1{}); } auto fusion_workspace = static_cast(workspace); @@ -318,7 +359,9 @@ public: tma_store_d, tma_descriptor_workspace, args.ptr_C, + args.dC, args.ptr_D, + args.dD, transaction_bytes, }; } @@ -326,10 +369,11 @@ public: template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - constexpr uint32_t NumInputTensors = cute::is_void_v ? 1 : 2; + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 
0 : 1); + auto descriptors_shape = cute::make_shape(sm_count, Int{}); constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); + return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); } template @@ -342,30 +386,40 @@ public: template static bool can_implement( - ProblemShape const& problem_shape, + ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1); - auto [M,N,K,L] = problem_shape_MNKL; bool implementable = true; - if constexpr (is_destination_supported) { - constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); - } + bool fusion_implementable = true; - if constexpr (not cute::is_void_v) { - constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); - constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + if (problem_shape.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + } + else { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); } if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - if (!fusion_implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); } @@ -414,10 +468,14 @@ public: } CUTLASS_DEVICE auto - load_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { + load_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { // Initialize tma for loading constexpr bool IsLoad = true; - auto load_tensormaps = tensormaps_init(params, sm_count, sm_idx); + auto load_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, 0); return 
load_tensormaps; } @@ -426,7 +484,8 @@ public: class TileShapeMNK, class TileCoordMNKL, class TiledMma, - class TensorMapC + class TensorMapC, + __CUTE_REQUIRES(std::is_pointer_v) > CUTLASS_DEVICE auto load( @@ -440,7 +499,7 @@ public: TensorStorage& shared_tensors, TensorMapC const& load_tensormap, int subtile_idx=-1, - bool return_prior_state = false) { + bool wait_until_load_finishes = false) { using namespace cute; // Indexing variables @@ -478,17 +537,21 @@ public: auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + LoadPipelineState last_load_producer_state = load_pipe_producer_state; + // Predication for TMA load (one thread issues TMA load) bool issue_tma_load = cute::elect_one_sync(); // Acquire the lock for the first stage - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); load_pipeline.producer_acquire(load_pipe_producer_state); + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); // Pre-loop fusion callback entry point pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); - auto prior_state = load_pipe_producer_state; + LoadPipelineState prior_state = load_pipe_producer_state; + + bool did_load = false; CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { @@ -506,15 +569,18 @@ public: pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); // Execute the TMA load for C if needed - if (issue_tma_load && is_C_load_needed) { - copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), - bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_expect_transaction(load_pipe_producer_state); + if (is_C_load_needed) { + if (issue_tma_load) { + copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + last_load_producer_state = load_pipe_producer_state; + did_load = true; } // Commit TMA loads for this stage and release the lock load_pipeline.producer_commit(load_pipe_producer_state); - prior_state = load_pipe_producer_state; ++load_pipe_producer_state; } } @@ -522,17 +588,24 @@ public: // Post-loop fusion callback entry point pld_callbacks.end(); - if (not return_prior_state) { - return load_pipe_producer_state; - } else { - return prior_state; + if (wait_until_load_finishes && did_load) { + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = + {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; + load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); } + + return load_pipe_producer_state; } CUTLASS_DEVICE auto load_tail( LoadPipeline load_pipeline, LoadPipelineState load_pipe_producer_state) { + + if (!fusion_callbacks.is_producer_load_needed()) { + return load_pipe_producer_state; + } + bool issue_tma_load = cute::elect_one_sync(); if (issue_tma_load) { load_pipeline.producer_tail(load_pipe_producer_state); @@ -564,6 +637,7 @@ public: TensorStorage& shared_tensors, TensorMapD const& store_tensormap, int subtile_idx=-1) { + using namespace cute; using ElementAccumulator = typename AccEngine::value_type; using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; @@ -869,11 
+943,22 @@ public: } CUTLASS_DEVICE auto - store_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { - // Initialize tma - constexpr bool IsLoad = false; - auto store_tensormaps = tensormaps_init(params, sm_count, sm_idx); - return store_tensormaps; + store_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup; + // Since only one warp issues TMA store, we only need that one warp to initialize tensormaps + if (warp_idx_in_warp_group == 0) { + // Initialize tma + constexpr bool IsLoad = false; + auto store_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx); + return store_tensormaps; + } + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); } // @@ -882,53 +967,45 @@ public: template CUTLASS_DEVICE auto - tensormaps_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { - cute::TmaDescriptor* tma_desc = nullptr; - cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + tensormaps_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); + Layout desc_layout = make_layout(make_shape(sm_count, Int{})); + + Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) + if constexpr (IsLoad) { if (not cute::is_void_v) { - tma_desc = &gmem_tensormap[sm_idx]; + constexpr int C_tensormap_index = NumEpilogueWarpGroups; + Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { - // Bringing tensormaps from params to gmem for modification later - Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gC_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{}); - copy(recast(pC_tensormap), recast(gC_tensormap)); + // Bringing tensormaps from params to smem for modification later + copy(recast(pC_tensormap), recast(sC_tensormap)); } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); } - } else { - int const offset_Ddesc = cute::is_void_v ? 
0 : sm_count; - tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); + } + else { + Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { - // Bringing tensormaps from params to gmem for modification later - Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gD_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{}); - copy(recast(pD_tensormap), recast(gD_tensormap)); + // Bringing tensormaps from params to smem for modification later + copy(recast(pD_tensormap), recast(sD_tensormap)); } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx)); } - - return cute::make_tuple(tma_desc); - } - - // Bringing tensormaps to smem (to be done by single thread) - template - CUTLASS_DEVICE - void - tensormaps_fetch_to_smem( - TensorMapStorage& shared_tensormap, - cute::TmaDescriptor const* tensormap) const { - if constexpr (IsLoad) { - if (not cute::is_void_v) { - Tensor gC_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{}); - Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{}); - copy(recast(gC_tensormap), recast(sC_tensormap)); - } - } else { - Tensor gD_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{}); - Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{}); - copy(recast(gD_tensormap), recast(sD_tensormap)); - } - cp_async_fence(); - cp_async_wait<0>(); } // Replace address for the global tensor (to be done by single thread) @@ -936,35 +1013,99 @@ public: CUTLASS_DEVICE void tensormaps_replace_global_address( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& params, - int32_t next_batch) { + int32_t next_batch, + int32_t warp_group_idx) { // Replacing global_address for the next batch if constexpr (IsLoad) { - if (not cute::is_void_v) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C, + if constexpr (is_source_supported) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, params.ptr_C[next_batch]); } - } else { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, + } + else if constexpr (is_destination_supported) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], params.ptr_D[next_batch]); } } - template + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl, + int32_t warp_group_idx) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + // Only consider dimensions and strides that we need to recalculate and replace for each group + constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N + static_assert(TensorRank == Int<3>{}, + "Descriptor modification for global dims & strides expects rank as 3."); + + cute::array prob_shape = {1,1,1}; + cute::array prob_stride = {0,0,0}; + 
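// Illustrative sketch, not from the patch: the per-group tensormap update fills prob_shape and
// prob_stride in elements and then converts to byte strides in the loops just below, because
// TMA descriptors expect byte strides. The conversion arithmetic, stated standalone:
#include <cstdint>

constexpr uint64_t to_byte_stride(uint64_t stride_in_elements, int element_bits) {
  return (stride_in_elements * element_bits) / 8;  // mirrors (stride * sizeof_bits_v<Element>) / 8
}

static_assert(to_byte_stride(8192, 16) == 16384);  // e.g. a 16-bit C/D element type
static_assert(to_byte_stride(8192, 32) == 32768);  // e.g. float output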
+ if constexpr (IsLoad) { + if constexpr (is_source_supported) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); + } + } + else if constexpr (is_destination_supported) { + ElementD const* ptr_D = nullptr; + + // tma_store_c should be a gmem_tensor, second argument should be a stride + + Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], + prob_shape, + prob_stride); + } + } + + template CUTLASS_DEVICE void tensormaps_perform_update( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& params, cute::TmaDescriptor const* tensormap, - int32_t next_batch) { + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch, + int32_t warp_group_idx) { if (cute::elect_one_sync()) { - // Bringing tensormaps to smem - tensormaps_fetch_to_smem(shared_tensormap, tensormap); // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormap, params, next_batch); + tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties( + shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx); + } } } @@ -972,16 +1113,18 @@ public: CUTLASS_DEVICE void tensormaps_cp_fence_release( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, cute::TmaDescriptor const* tensormap, - [[maybe_unused]] uint32_t lane_predicate) { + [[maybe_unused]] uint32_t lane_predicate, + int32_t warp_group_idx = 0) { // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { - if (not cute::is_void_v) { - tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); + if constexpr (is_source_supported) { + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C); } - } else { - tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); + } + else if constexpr (is_destination_supported) { + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]); } } @@ -990,10 +1133,11 @@ public: void tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { if constexpr (IsLoad) { - if (not cute::is_void_v) { + if constexpr (not cute::is_void_v) { cute::tma_descriptor_fence_acquire(tensormap); } - } else { + } + else { cute::tma_descriptor_fence_acquire(tensormap); } } diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 9f9576b4..e96f4134 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -51,7 +51,21 @@ struct 
PtrArrayNoSmemWarpSpecialized {}; struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; -struct PtrArrayTmaWarpSpecializedCooperative {}; + +struct PtrArrayTmaWarpSpecializedCooperative { + static constexpr int NumEpilogueWarpGroups = 2; +}; + +// Standard warp specialized epilogue +struct PtrArrayTmaWarpSpecialized { + static constexpr int NumEpilogueWarpGroups = 1; +}; + +// Pingpong kernel epilogue +struct PtrArrayTmaWarpSpecializedPingpong { + static constexpr int NumEpilogueWarpGroups = 2; +}; + // DEPRECATED schedules, will be removed in next release struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; @@ -151,7 +165,8 @@ template< int StagesD_, int FragmentSize_, bool ReuseSmemC_, - bool DelayTmaStore_ + bool DelayTmaStore_, + int NumEpilogueWarpGroups_ > struct Sm90PtrArrayTmaWarpSpecialized { constexpr static int StagesC = StagesC_; @@ -159,6 +174,7 @@ struct Sm90PtrArrayTmaWarpSpecialized { constexpr static int FragmentSize = FragmentSize_; constexpr static bool ReuseSmemC = ReuseSmemC_; constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; }; // DEPRECATED policies, will be removed in next release diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index a0128877..0bfacf34 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -32,6 +32,7 @@ #pragma once #include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 3f43f60d..ece5ac54 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -179,6 +179,89 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombinationPtrArray = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcastPtrArray>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); 
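// Illustrative sketch, not from the patch: the dAlpha/dBeta strides declared in these Arguments
// (and populated by example 57 above) use the batch/group stride to select between one scalar
// broadcast to all groups (stride 0) and one scalar per group (stride 1), mirroring
// l_offset = l_coord * size<2>(dScalar) in Sm90ScalarBroadcastPtrArray::update_scalar below.
constexpr int scalar_offset(int group_idx, int group_stride) {
  return group_idx * group_stride;
}

static_assert(scalar_offset(/*group_idx=*/5, /*group_stride=*/0) == 0);  // same alpha/beta for every group
static_assert(scalar_offset(/*group_idx=*/5, /*group_stride=*/1) == 5);  // per-group alpha/beta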
+ ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int>; + using StrideBeta = Stride<_0,_0,int>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = activation(alpha * acc + beta * C) template< template class ActivationFn, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 4eb326b3..aedacb55 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/collective/detail.hpp" #include "cute/tensor.hpp" #include "sm90_visitor_tma_warpspecialized.hpp" @@ -514,7 +515,8 @@ private: if (params_ptr->scalar_ptrs[0] != nullptr) { scalar = params_ptr->scalar_ptrs[0][l_offset]; - } else { + } + else { // batch stride is ignored for nullptr fallback scalar = params_ptr->scalars[0]; } @@ -541,6 +543,169 @@ private: } }; +// Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors +template< + class Element, + class StrideMNL = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcastPtrArray { + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + Element const* const* scalar_ptr_arrays[BroadcastCount] = {}; + StrideMNL dScalar[BroadcastCount] = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + // producer load is needed if Element is not void and we have multiple scalars + return !cute::is_void_v and size<2>(params_ptr->dScalar[0]) != 0; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + 
return false; + } + + // This must be called after update_scalar is called + CUTLASS_DEVICE bool + is_zero() const { + return scalar == Element(0); + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if (size<2>(params_ptr->dScalar[0]) == 0) { + update_scalar(); + } + } + + Element scalar; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); + + if (params_ptr->scalar_ptr_arrays[0] != nullptr) { + scalar = *(params_ptr->scalar_ptr_arrays[0][l_offset]); + } + else if (params_ptr->scalar_ptrs[0] != nullptr) { + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } + else { + // batch stride is ignored for nullptr fallback + scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + + if (params_ptr->scalar_ptr_arrays[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][rest_l_offset])); + } + if (params_ptr->scalar_ptrs[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { + // batch stride is ignored for nullptr fallback + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 25b1f848..0b3ecb15 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -31,6 +31,10 @@ #pragma once #include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/collective/collective_mma_decl.hpp" +#include "cutlass/gemm/collective/collective_builder_decl.hpp" // SM90 Collective Builders should be 
used only starting CUDA 12.0 #if (__CUDACC_VER_MAJOR__ >= 12) @@ -177,10 +181,12 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< - (cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v) && + (cute::is_any_of_v) && not detail::is_use_rmem_A()> > { static_assert(is_static::value); @@ -191,10 +197,12 @@ struct CollectiveBuilder< static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); - static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); + static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), - "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); + "KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now."); // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; @@ -203,8 +211,10 @@ struct CollectiveBuilder< static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - using AtomLayoutMNK = cute::conditional_t< - cute::is_same_v || IsArrayOfPointersGemm, + static constexpr bool IsCooperative = cute::is_any_of_v; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< @@ -218,7 +228,10 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(TensorMapStorage); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = cute::conditional_t, @@ -505,10 +518,12 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v> + cute::is_any_of_v> > { static_assert(is_static::value); static_assert(is_static::value); @@ -526,10 +541,15 @@ struct CollectiveBuilder< static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); - using AtomLayoutMNK = cute::conditional_t || - IsArrayOfPointersGemm, - Layout>, Layout>>; + static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v; + + static constexpr bool IsCooperative = cute::is_any_of_v; + + using AtomLayoutMNK = cute::conditional_t>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); @@ -542,7 +562,11 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(TensorMapStorage); + static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int PipelineStages = 
detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = cute::conditional_t, diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 4f2837d1..75d7bb39 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -623,56 +623,42 @@ struct CollectiveMma< // CUTLASS_DEVICE auto - tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const { + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; if (cute::elect_one_sync()) { - // Bringing tensormaps from params to gmem for modification later + // Bringing tensormaps from params to smem for modification later Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); - copy(recast(pA_tensormap), recast(gA_tensormap)); - copy(recast(pB_tensormap), recast(gB_tensormap)); + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); } + __syncwarp(); return cute::make_tuple(tma_desc_a, tma_desc_b); } - // Bringing tensormaps to smem (to be done by single thread) - template - CUTLASS_DEVICE - void - tensormaps_fetch_to_smem( - TensorMapStorage& shared_tensormap, - cute::tuple const& input_tensormaps) const { - Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{}); - Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{}); - Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{}); - Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{}); - - copy(recast(gA_tensormap), recast(sA_tensormap)); - copy(recast(gB_tensormap), recast(sB_tensormap)); - - cp_async_fence(); - cp_async_wait<0>(); - } - // Replace address for the global tensor (to be done by single thread) CUTLASS_DEVICE void tensormaps_replace_global_address( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& mainloop_params, int32_t next_batch) { // Replacing global_address for the next batch - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A, + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, mainloop_params.ptr_A[next_batch]); - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B, + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, mainloop_params.ptr_B[next_batch]); } @@ -681,7 +667,7 @@ struct CollectiveMma< CUTLASS_DEVICE 
void tensormaps_replace_global_tensor_properties( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& mainloop_params, int32_t next_group, ProblemShape_MNKL problem_shape_mnkl) { @@ -716,10 +702,10 @@ struct CollectiveMma< stride = (stride * sizeof_bits_v) / 8; } - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A, + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, prob_shape_A, prob_stride_A); - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_B, + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, prob_shape_B, prob_stride_B); } @@ -728,21 +714,19 @@ struct CollectiveMma< CUTLASS_DEVICE void tensormaps_perform_update( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& mainloop_params, cute::tuple const& input_tensormaps, ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch) { if (cute::elect_one_sync()) { - // Bringing tensormaps to smem - tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps); // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch); + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); if constexpr (IsGroupedGemmKernel) { // Replacing global dims and strides for the next batch - tensormaps_replace_global_tensor_properties(shared_tensormap, + tensormaps_replace_global_tensor_properties(shared_tensormaps, mainloop_params, next_batch, problem_shape_mnkl); } } @@ -752,11 +736,11 @@ struct CollectiveMma< CUTLASS_DEVICE void tensormaps_cp_fence_release ( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, cute::tuple const& input_tensormaps) { // Entire warp must do this (i.e. 
it's aligned) - tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); - tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); } // The entire warp must call this function collectively (that is, the instructions are aligned) diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 2e820b61..c1c2308b 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -98,6 +98,7 @@ struct KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedPingpong { }; ////////////////////////////////////////////////////////////////////////////// @@ -111,6 +112,7 @@ struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum : KernelPtrArrayTmaWarpSpecializedPingpong { }; // Policies to opt into mixed type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -286,8 +288,9 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized { using ArchTag = arch::Sm90; using Schedule = KernelSchedule; static_assert( - cute::is_base_of_v, - "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies"); + cute::is_base_of_v || + cute::is_base_of_v, + "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies"); }; ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index b682be86..6c7b89a2 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -61,5 +61,6 @@ struct IsCutlass3ArrayKernel or "); + + static_assert(cute::is_base_of_v); + // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; @@ -119,8 +124,9 @@ public: using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaThreads = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; /// Register requirement for Load and Math WGs @@ -215,11 +221,11 @@ public: workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); void* 
epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); void* mainloop_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used @@ -275,9 +281,6 @@ public: args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { @@ -286,6 +289,9 @@ public: sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); } + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); @@ -363,6 +369,12 @@ public: static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); static_assert(size<0>(TileShape{}) >= 128, "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + static_assert(NumMmaWarpGroups == 2, "Cooperative kernels currently only support NumMmaWarpGroups == 2"); + + if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v) { + static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups, + "Tiled MmA does not match expected warp groups performing the epilogue"); + } static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. 
If batch mode is not needed, set L stride to Int<0>."); @@ -391,7 +403,8 @@ public: int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; int mma_thread_idx = thread_idx % size(TiledMma{}); - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto warp_group_idx = canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); int lane_predicate = cute::elect_one_sync(); uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); @@ -466,7 +479,9 @@ public: // Get the appropriate blocks for this thread block -- potential for thread block locality TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); TileScheduler scheduler{params.scheduler}; @@ -484,7 +499,7 @@ public: } // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); // Prepare and partition the input tensors. Expects a tuple of tensors where: // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) @@ -510,7 +525,7 @@ public: int32_t const sm_count = params.hw_info.sm_count; // Fetch a copy of tensormaps for the CTA - auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, sm_count, sm_idx); + auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx); // Update tensormap for the initial batch for the CTA if (work_tile_info.is_valid()) { @@ -578,7 +593,7 @@ public: if (work_tile_info.is_valid() && did_batch_change) { curr_batch = next_batch; if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), Int<1>{}); + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch); } // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates // Since this state is waiting for loads to finish, it must start in the inverted phase. 
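// The hunk above ends where the producer warp detects a batch change: before it may rewrite the
// TMA descriptors held in shared memory it has to (1) wait until the in-flight TMA loads of the
// previous batch have drained, (2) patch the smem descriptor for the next batch, and (3) fence-release
// the descriptor back to the per-SM copy in global memory. A condensed sketch of that hand-off,
// reusing the names declared in this kernel (the same sequence appears in the pingpong kernel added
// further down in this patch):
//
//   // A producer state with the phase inverted acts as a consumer view of the pipeline, so
//   // consumer_wait() returns only once the outstanding loads for the previous batch have completed.
//   typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state =
//     {mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()};
//   mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state);
//
//   collective_mainloop.tensormaps_perform_update(
//     shared_storage.tensormaps.mainloop, params.mainloop,
//     input_tensormaps, problem_shape_MNKL, curr_batch);
//
//   __syncwarp();  // converge before the aligned fence release
//   collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);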
@@ -610,7 +625,7 @@ public: int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); int32_t const sm_count = params.hw_info.sm_count; - auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, sm_count, sm_idx)); + auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx)); bool did_batch_change = true; constexpr bool IsEpiLoad = true; @@ -620,23 +635,26 @@ public: shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, - work_tile_info.L_idx + problem_shape_MNKL, + work_tile_info.L_idx, + 0 ); // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); } load_order_barrier.wait(); while (work_tile_info.is_valid()) { int32_t curr_batch = work_tile_info.L_idx; - bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); - if (compute_epilogue) { + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); } // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape @@ -649,6 +667,8 @@ public: collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); } + bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; + epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -660,36 +680,34 @@ public: shared_storage.tensors.epilogue, epi_load_tensormap, work_tile_info.reduction_subtile_idx(), - true // return state prior to last advance + wait ); } - // Get next work tile - work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; did_batch_change = curr_batch != work_tile_info.L_idx; if (work_tile_info.is_valid() && did_batch_change) { - // Wait for TMA load to finish before updating - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = - {epi_load_pipe_producer_state.index(), !epi_load_pipe_producer_state.phase(), epi_load_pipe_producer_state.count()}; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } - epi_load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); + // tensormap update + { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + 0 + ); - collective_epilogue.tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_load_tensormap, - work_tile_info.L_idx - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate); - } - - if(compute_epilogue) { - epi_load_pipe_producer_state.advance(1); + // Converge 
before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); + } } } // Scheduler work fetch loop @@ -702,32 +720,43 @@ public: else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1; + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); int32_t const sm_count = params.hw_info.sm_count; // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it bool do_store_tail = false; // Get a copy of tensormaps - auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, sm_count, sm_idx)); + auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); bool did_batch_change = true; constexpr bool IsEpiLoad = false; if (work_tile_info.is_valid()) { - collective_epilogue.tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - work_tile_info.L_idx - ); + if (warp_idx_in_warp_group == 0) { - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate); + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } } while (work_tile_info.is_valid()) { if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); } int32_t curr_batch = work_tile_info.L_idx; @@ -743,6 +772,10 @@ public: // // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. 
auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + static_assert(cute::is_any_of_v, + detail::PersistentTileSchedulerSm90>); if(TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { collective_mainloop.mma( mainloop_pipeline, @@ -764,18 +797,16 @@ public: // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(work_k_tile_count); } - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; // Perform reduction across splits, if needed TileScheduler::fixup( params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + } - if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); - } + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = @@ -804,20 +835,31 @@ public: did_batch_change = curr_batch != work_tile_info.L_idx; if (work_tile_info.is_valid() && did_batch_change) { - collective_epilogue.tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - work_tile_info.L_idx - ); + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + if (warp_idx_in_warp_group == 0) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate); + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } } } // Scheduler work fetch loop + // Cooperative only needs TMA to complete at the very end of the kernel if (do_store_tail) { collective_epilogue.store_tail( epi_load_pipeline, @@ -829,7 +871,6 @@ public: } // Consumer Warp Groups End #endif } - }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp new file mode 100644 index 00000000..491ec0ec --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -0,0 +1,949 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t> +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + static_assert(cute::is_base_of_v); + + static constexpr bool IsGdcEnabled = false; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename 
CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + static_assert(cute::is_void_v, + "Ptr-Array Pingpong and Grouped Gemm Pingpong kernel only supports the default scheduler."); + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TileScheduler = cute::conditional_t::Scheduler, + typename detail::TileSchedulerSelector< + void, ArchTag, TileShape, ClusterShape>::Scheduler>; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; + + using MathWarpGroupOrderBarrierSharedStorage = cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage< + MathWarpGroupOrderBarrier::SequenceDepth, + MathWarpGroupOrderBarrier::SequenceLength>; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = MathWarpGroupOrderBarrierSharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + } pipelines; + + struct TensorMapStorage : cute::aligned_struct<128> { + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + using 
EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + + alignas(128) MainloopTensorMapStorage mainloop; + alignas(128) EpilogueTensorMapStorage epilogue; + } tensormaps; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + ProblemShape problem_shapes = args.problem_shape; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. 
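// get_store_pipe_increment(TileShape{}) reports how many epilogue subtiles the collective
// epilogue emits per CTA tile; that count is forwarded to the tile scheduler below.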
+ constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + } + else { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + } + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + hw_info, + scheduler, + workspace + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + if constexpr (IsGroupedGemmKernel) { + // Group GEMM currently only supports rank-3 problem shapes + implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); + } else { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, typename 
ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + else { + grid_shape = TileScheduler::get_grid_shape(params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args); + } + return grid_shape; + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(size(TiledMma{}) == 128, "Pingpong kernel must have TiledMMA operating using 128 threads."); + static_assert(NumMmaWarpGroups == 2, "Pingpong kernels currently only support NumMmaWarpGroups == 2"); + + if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v) { + static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups, + "Tiled MmA does not match expected warp groups performing the epilogue"); + } + + static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_idx = canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Note: Tma Descriptor Prefetch (from either const or param) is not applicable here + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 
0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = warp_group_idx - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups + return; + } + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + + if (warp_group_role == WarpGroupRole::Consumer1) { + // Advance 2nd Math WG to the next work tile for the startup + const auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + if (!work_tile_info.is_valid()) { + return; + } + + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + problem_shape_MNKL = 
append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; + int32_t const mock_l_coord = 0; + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + + // Fetch a copy of tensormaps for the CTA + auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx); + + // Update tensormap for the initial batch for the CTA + if (work_tile_info.is_valid()) { + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape_MNKL, + curr_batch + ); + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (i.e. it's aligned) + collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); + } + + bool do_load_order_arrive = true; + bool did_batch_change = true; + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
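// (With the default persistent schedulers this kernel supports, the work unit spans the
// full K extent and work_k_tile_start is 0; no split-K/stream-K partial-K path is taken.)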
+ auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + if (did_batch_change) { + collective_mainloop.tensormaps_fence_acquire(input_tensormaps); + } + + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + input_tensormaps, + blk_coord, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + // Wait for the last TMA stage to complete loading, before issuing tensormap updates + mainloop_pipe_producer_state.advance(work_k_tile_count - 1); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx + did_batch_change = next_batch != curr_batch; + if (work_tile_info.is_valid() && did_batch_change) { + curr_batch = next_batch; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch); + } + // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates + // Since this state is waiting for loads to finish, it must start in the inverted phase. + typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state = + {mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()}; + mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state); + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape_MNKL, + curr_batch + ); + // Ensure warp is converged before issuing tensor replace + __syncwarp(); + // Entire warp must do this (i.e. 
it's aligned) + collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); + } + // Advance the producer state for the last remaining stage that was being waited for above + mainloop_pipe_producer_state.advance(1); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + + auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + if (work_tile_info.is_valid()) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + 0 + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); + } + + load_order_barrier.wait(); + + while (work_tile_info.is_valid()) { + int32_t curr_batch = work_tile_info.L_idx; + + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); + } + + bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue, + epi_load_tensormap, + work_tile_info.reduction_subtile_idx(), + wait + ); + } + + work_tile_info = next_work_tile_info; + did_batch_change = curr_batch != work_tile_info.L_idx; + + if (work_tile_info.is_valid() && did_batch_change) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + // tensormap update + { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + 0 + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); + } + } + + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + 
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + cutlass::arch::warpgroup_reg_alloc(); + + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1; + + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + // Get a copy of tensormaps + auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + + if (work_tile_info.is_valid()) { + + if (warp_idx_in_warp_group == 0) { + + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } + } + + while (work_tile_info.is_valid()) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + int32_t curr_batch = work_tile_info.L_idx; + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. 
+ auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + static_assert(cute::is_any_of_v, + detail::PersistentTileSchedulerSm90>); + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + + math_wg_order_barrier.wait(); + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count + ); + + math_wg_order_barrier.wait(); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); + + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + } + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue, + epi_store_tensormap, + work_tile_info.reduction_subtile_idx() + ); + + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + + // Skip a tile for pingpong + if (work_tile_info.is_valid()) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + mainloop_pipe_consumer_state.advance(work_k_tile_count); + + // Go to next tile + auto next_next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + + work_tile_info = next_next_work_tile_info; + } + + did_batch_change = curr_batch != work_tile_info.L_idx; + if (work_tile_info.is_valid() && did_batch_change) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + if (warp_idx_in_warp_group == 0) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } + } + + // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels + // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives + // to ensure next math consumer doesn't overwrite smem 
of in-flight TMA stores of current consumer. + auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] = + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state + ); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + } // Scheduler work fetch loop + } // Consumer Warp Groups End +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel \ No newline at end of file diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index a70ce542..348b185c 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -338,12 +338,14 @@ cutlass_test_unit_add_executable( cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu ) # Group Gemm test cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu ) # Fused epilogue tests diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index e2d3f2d0..085d3e74 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -1005,7 +1005,7 @@ struct HostCollectiveEpilogue { stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); } - static_assert(!IsGroupGemm or (IsGroupGemm and IsAuxOutEnabled)); + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); if constexpr (IsAuxOutEnabled) { for (int32_t i = 0; i < L; ++i) { @@ -1323,8 +1323,16 @@ struct HostCollectiveEpilogue { cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); - auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? 
tensors_Aux[batch].host_data() : references_Aux[batch].host_data()),
-                            cute::make_layout(cute::make_shape(M, N, 1), stride_Aux));
+    auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux);
+    auto Aux = [&]() {
+      auto ptr = recast_ptr<ElementAux>(nullptr);
+      if (IsAuxInEnabled) {
+        ptr = detail::make_iterator(tensors_Aux[batch].host_data());
+      } else if (IsAuxOutEnabled) {
+        ptr = detail::make_iterator(references_Aux[batch].host_data());
+      }
+      return cute::make_tensor(ptr, Aux_layout);
+    }();
     auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()),
       cute::make_layout(cute::make_shape(M, cute::_1{})));
     auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()),
diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu
index b93d9368..2a6d2339 100644
--- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu
+++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu
@@ -78,7 +78,69 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;    // M
 using ElementAccumulator = float;                                   // Element type for internal accumulation
 using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
 using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
-using TileShape = Shape<_256,_128,_64>;                             // Threadblock-level tile size
+using TileShape = Shape<_128,_128,_64>;                             // Threadblock-level tile size
+using ClusterShape = Shape<_2,_2,_1>;                               // Shape of the threadblocks in a cluster
+using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    TileShape, ClusterShape,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ElementAccumulator, ElementAccumulator,
+    ElementC, LayoutC *, AlignmentC,
+    ElementC, LayoutC *, AlignmentC,
+    EpilogueSchedule
+  >::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+    ElementA, LayoutA *, AlignmentA,
+    ElementB, LayoutB *, AlignmentB,
+    ElementAccumulator,
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    KernelSchedule
+  >::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+  cutlass::gemm::GroupProblemShape<Shape<int,int,int,int>>,
+  CollectiveMainloop,
+  CollectiveEpilogue
+>;
+
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestAll<Gemm>(1.0, 1.0);
+  EXPECT_TRUE(result);
+  result = TestAll<Gemm>(1.0, 0.0);
+  EXPECT_TRUE(result);
+}
+
+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm, 128x128x64_2x2x1_direct_store) {
+
+// A matrix configuration
+using ElementA = cutlass::half_t;                                   // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                          // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+// B matrix configuration
+using ElementB = cutlass::half_t;                                   // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor;                       // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// C/D matrix configuration
+using ElementC = cutlass::half_t;                                   // Element type for C and D matrix operands
+using LayoutC = cutlass::layout::ColumnMajor;                       // Layout type for C and D matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+// Core kernel configurations
+using ElementAccumulator = float;                                   // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
+using TileShape = Shape<_128,_128,_64>;                             // Threadblock-level tile size
 using ClusterShape = Shape<_2,_2,_1>;                               // Shape of the threadblocks in a cluster
 using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
 using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
@@ -115,6 +177,8 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
 using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 bool result = TestAll<Gemm>(1.0, 1.0);
 EXPECT_TRUE(result);
+  result = TestAll<Gemm>(1.0, 0.0);
+  EXPECT_TRUE(result);
 }
 
 #endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu
new file mode 100644
index 00000000..09be8a49
--- /dev/null
+++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu
@@ -0,0 +1,184 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide Ptr-Array Ping-pong scheduler GEMM interface
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+
+#include "../../common/cutlass_unit_test.h"
+
+#include "gemm_testbed_3x_ptr_array.hpp"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+using namespace cute;
+
+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1) {
+
+// A matrix configuration
+using ElementA = cutlass::half_t;                                   // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                          // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+// B matrix configuration
+using ElementB = cutlass::half_t;                                   // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor;                       // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// C/D matrix configuration
+using ElementC = cutlass::half_t;                                   // Element type for C and D matrix operands
+using LayoutC = cutlass::layout::ColumnMajor;                       // Layout type for C and D matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+// Core kernel configurations
+using ElementAccumulator = float;                                   // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
+using TileShape = Shape<_128,_128,_64>;                             // Threadblock-level tile size
+using ClusterShape = Shape<_2,_2,_1>;                               // Shape of the threadblocks in a cluster
+using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; // Epilogue to launch
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    TileShape, ClusterShape,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ElementAccumulator, ElementAccumulator,
+    ElementC, LayoutC *, AlignmentC,
+    ElementC, LayoutC *, AlignmentC,
+    EpilogueSchedule
+  >::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+    ElementA, LayoutA *, AlignmentA,
+    ElementB, LayoutB *, AlignmentB,
+    ElementAccumulator,
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    KernelSchedule
+  >::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+  cutlass::gemm::GroupProblemShape<Shape<int,int,int,int>>,
+  CollectiveMainloop,
+  CollectiveEpilogue
+>;
+
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestAll<Gemm>(1.0, 1.0);
+  EXPECT_TRUE(result);
+  result = TestAll<Gemm>(1.0, 0.0);
+  EXPECT_TRUE(result);
+}
+
+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1_direct_store) {
+
+// A matrix configuration
+using ElementA = cutlass::half_t;                                   // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                          // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+// B matrix configuration
+using ElementB = cutlass::half_t;                                   // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor;                       // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// C/D matrix configuration
+using ElementC = cutlass::half_t;                                   // Element type for C and D matrix operands
+using LayoutC = cutlass::layout::ColumnMajor;                       // Layout type for C and D matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+// Core kernel configurations
+using ElementAccumulator = float;                                   // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
+using TileShape = Shape<_128,_128,_64>;                             // Threadblock-level tile size
+using ClusterShape = Shape<_2,_2,_1>;                               // Shape of the threadblocks in a cluster
+using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    TileShape, ClusterShape,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ElementAccumulator, ElementAccumulator,
+    ElementC, LayoutC *, AlignmentC,
+    ElementC, LayoutC *, AlignmentC,
+    EpilogueSchedule
+  >::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+    ElementA, LayoutA *, AlignmentA,
+    ElementB, LayoutB *, AlignmentB,
+    ElementAccumulator,
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    KernelSchedule
+  >::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+  cutlass::gemm::GroupProblemShape<Shape<int,int,int,int>>,
+  CollectiveMainloop,
+  CollectiveEpilogue
+>;
+
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestAll<Gemm>(1.0, 1.0);
+  EXPECT_TRUE(result);
+  result = TestAll<Gemm>(1.0, 0.0);
+  EXPECT_TRUE(result);
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
\ No newline at end of file
diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
index dc581acf..53748dc8 100644
--- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
+++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
@@ -115,9 +115,11 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
 using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 bool result = TestAll<Gemm>(1.0, 1.0);
 EXPECT_TRUE(result);
+  result = TestAll<Gemm>(1.0, 0.0);
+  EXPECT_TRUE(result);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_NoSmemEpi) {
+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_direct_store) {
 
 // A matrix configuration
 using ElementA = cutlass::half_t;                                   // Element type for A matrix operand
@@ -173,6 +175,7 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
 
 using namespace test::gemm::device;
 using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  EXPECT_TRUE(TestAll<Gemm>(1.0, 1.0));
 EXPECT_TRUE(TestAll<Gemm>(1.0, 0.0));
 }
 
diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu
new file mode 100644
index 00000000..5b91825c
--- /dev/null
+++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu
@@ -0,0 +1,182 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for device-wide Ptr-Array GEMM interface
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+
+#include "../../common/cutlass_unit_test.h"
+
+#include "gemm_testbed_3x_ptr_array.hpp"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+using namespace cute;
+
+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1) {
+
+// A matrix configuration
+using ElementA = cutlass::half_t;                                   // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                          // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+// B matrix configuration
+using ElementB = cutlass::half_t;                                   // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor;                       // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// C/D matrix configuration
+using ElementC = cutlass::half_t;                                   // Element type for C and D matrix operands
+using LayoutC = cutlass::layout::ColumnMajor;                       // Layout type for C and D matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+// Core kernel configurations
+using ElementAccumulator = float;                                   // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
+using TileShape = Shape<_128,_128,_64>;                             // Threadblock-level tile size
+using ClusterShape = Shape<_2,_2,_1>;                               // Shape of the threadblocks in a cluster
+using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; // Epilogue to launch
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    TileShape, ClusterShape,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ElementAccumulator, ElementAccumulator,
+    ElementC, LayoutC, AlignmentC,
+    ElementC, LayoutC, AlignmentC,
+    EpilogueSchedule
+  >::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+    ElementA, LayoutA, AlignmentA,
+    ElementB, LayoutB, AlignmentB,
+    ElementAccumulator,
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    KernelSchedule
+  >::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+  cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
+  CollectiveMainloop,
+  CollectiveEpilogue
+>;
+
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestAll<Gemm>(1.0, 1.0);
+  EXPECT_TRUE(result);
+  result = TestAll<Gemm>(1.0, 0.0);
+  EXPECT_TRUE(result);
+}
+
+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1_direct_store) {
+
+// A matrix configuration
+using ElementA = cutlass::half_t;                                   // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                          // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+// B matrix configuration
+using ElementB = cutlass::half_t;                                   // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor;                       // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// C/D matrix configuration
+using ElementC = cutlass::half_t;                                   // Element type for C and D matrix operands
+using LayoutC = cutlass::layout::ColumnMajor;                       // Layout type for C and D matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+// Core kernel configurations
+using ElementAccumulator = float;                                   // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm90;                                // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp;               // Operator class tag
+using TileShape = Shape<_128,_128,_64>;                             // Threadblock-level tile size
+using ClusterShape = Shape<_2,_2,_1>;                               // Shape of the threadblocks in a cluster
+using StageCountType = cutlass::gemm::collective::StageCountAuto;   // Stage count maximized based on the tile size
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    TileShape, ClusterShape,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ElementAccumulator, ElementAccumulator,
+    ElementC, LayoutC, AlignmentC,
+    ElementC, LayoutC, AlignmentC,
+    EpilogueSchedule
+  >::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+    ElementA, LayoutA, AlignmentA,
+    ElementB, LayoutB, AlignmentB,
+    ElementAccumulator,
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    KernelSchedule
+  >::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+  cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
+  CollectiveMainloop,
+  CollectiveEpilogue
+>;
+
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  EXPECT_TRUE(TestAll<Gemm>(1.0, 1.0));
+  EXPECT_TRUE(TestAll<Gemm>(1.0, 0.0));
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
\ No newline at end of file
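
For context, the new pingpong kernels exercised by these tests are driven through the same GemmUniversalAdapter host flow as the existing cooperative ones. The sketch below is illustrative only and is not part of the patch: GemmT stands for a grouped pingpong adapter type like the Gemm alias built in the tests above, the per-group problem-size, pointer, and stride arrays are hypothetical caller-provided allocations, and the Arguments field order follows the grouped-GEMM convention used in the CUTLASS 3.x examples rather than a definitive interface.

// Sketch of a host-side launch for a grouped pingpong GEMM (assumptions noted above;
// the kernel type and testbed headers from the files above are assumed to be included).
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/device_memory.h"

template <class GemmT, class ProblemShapeT,
          class PtrA, class StrideA, class PtrB, class StrideB,
          class PtrC, class StrideC, class PtrD, class StrideD>
cutlass::Status run_grouped_pingpong(
    int num_groups,
    ProblemShapeT* problem_sizes_device,        // per-group (M,N,K,1) shapes on device (hypothetical)
    ProblemShapeT const* problem_sizes_host,    // optional host copy of the shapes (may be nullptr)
    PtrA ptr_A, StrideA stride_A,               // per-group pointer / stride arrays for A (hypothetical)
    PtrB ptr_B, StrideB stride_B,               // ... for B
    PtrC ptr_C, StrideC stride_C,               // ... for C
    PtrD ptr_D, StrideD stride_D,               // ... for D
    cutlass::KernelHardwareInfo hw_info) {

  // Field order mirrors the grouped-GEMM convention in the CUTLASS examples (assumption).
  typename GemmT::Arguments args{
    cutlass::gemm::GemmUniversalMode::kGrouped,
    {num_groups, problem_sizes_device, problem_sizes_host},
    {ptr_A, stride_A, ptr_B, stride_B},
    {{1.0f, 0.0f}, ptr_C, stride_C, ptr_D, stride_D},   // alpha = 1, beta = 0
    hw_info
  };

  GemmT gemm;
  if (gemm.can_implement(args) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Query and allocate any extra workspace the kernel needs, then initialize and run.
  cutlass::device_memory::allocation<uint8_t> workspace(GemmT::get_workspace_size(args));
  cutlass::Status status = gemm.initialize(args, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return gemm.run();
}

The same flow applies to the ptr-array variants; only the problem-shape argument changes (a single (M,N,K,L) shape under GemmUniversalMode::kArray instead of the per-group shape arrays shown here).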