diff --git a/CHANGELOG.md b/CHANGELOG.md index 00728725..ed464f61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,15 +13,17 @@ - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). + - [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu). - [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu). - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu). -* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu). +* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/): both [forward](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) and [backward](./examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu) passes are supported. * A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture. * Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. 
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [grouped-wise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler. - Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e6f298e..b54b8335 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -713,6 +713,7 @@ target_include_directories( CUTLASS SYSTEM INTERFACE $ + $ ) install( diff --git a/README.md b/README.md index 433c375c..50fae016 100644 --- a/README.md +++ b/README.md @@ -50,15 +50,17 @@ architecture. - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). 
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). + - [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu). - [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu). - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu). -* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu). +* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/): both [forward](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) and [backward](./examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu) passes are supported. * A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture. * Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. * Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. 
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [grouped-wise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler. - Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. diff --git a/examples/04_tile_iterator/tile_iterator.cu b/examples/04_tile_iterator/tile_iterator.cu index c963d95e..025eb65f 100644 --- a/examples/04_tile_iterator/tile_iterator.cu +++ b/examples/04_tile_iterator/tile_iterator.cu @@ -34,7 +34,7 @@ addressable memory, and then store it back into addressable memory. TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to - and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines + and from addressable memory. The PredicatedTileIterator accepts a ThreadMap type, which defines the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined thread mappings to be specified. 
diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu index 1c21678f..5d4fe1a1 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -75,11 +75,11 @@ #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" // Includes from examples directory #include "helper.h" #include "hopper_fp8_commandline.hpp" -#include "reference/host/gemm_with_blockwise_scaling.h" using namespace cute; @@ -123,7 +123,13 @@ using ArchTag = cutlass::arch::Sm90; // T using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>; + +using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(TileShape{})); + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; @@ -143,8 +149,8 @@ using CollectiveEpilogue = typename 
cutlass::epilogue::collective::CollectiveBui using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< @@ -190,20 +196,22 @@ StrideB stride_B; StrideC stride_C; StrideD stride_D; StrideAux stride_aux; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; uint64_t seed; +using LayoutScalar = cutlass::layout::PackedVectorLayout; cutlass::HostTensor tensor_A; cutlass::HostTensor tensor_B; cutlass::HostTensor tensor_C; cutlass::HostTensor tensor_D; uint32_t mma_promotion_interval; -cutlass::HostTensor blockscale_tensor_A; -cutlass::HostTensor blockscale_tensor_B; +cutlass::HostTensor blockscale_tensor_A; +cutlass::HostTensor blockscale_tensor_B; cutlass::HostTensor tensor_ref_D; cutlass::HostTensor tensor_aux; cutlass::HostTensor tensor_ref_aux; -using LayoutScalar = cutlass::layout::PackedVectorLayout; cutlass::HostTensor scalar_alpha; cutlass::HostTensor scalar_beta; cutlass::HostTensor scale_A; @@ -342,26 +350,25 @@ bool initialize_scale_tensor( /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const Options &options) { - // Find Block Scaling tensor shapes based on problem shape and TileShape - auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); - auto blockscale_m = cute::get<0>(blockscale_shape); - auto blockscale_n = cute::get<1>(blockscale_shape); - auto blockscale_k = cute::get<2>(blockscale_shape); - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, 
options.l)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); stride_aux = stride_D; + // Layout SFA and SFB represent logically broadcasting data in CuTe. + // E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGraunularityK, K / ScaleGranularityK)) + // and strides ((0, 1), (0, M / ScaleGraunuarlityM)), then each collection of ScaleGranularityM x ScaleGranularityK + // indecies in the tensor map to the same offset. + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); - auto blockscale_a_coord = cutlass::make_Coord(blockscale_m * options.l, blockscale_k); - auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l); + auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); tensor_A.resize(a_coord); blockscale_tensor_A.resize(blockscale_a_coord); @@ -465,7 +472,9 @@ typename Gemm::Arguments args_from_options(const Options &op stride_B, mma_promotion_interval, blockscale_tensor_A.device_data(), - blockscale_tensor_B.device_data() + layout_SFA, + blockscale_tensor_B.device_data(), + layout_SFB }, { {}, // epilogue.thread @@ -519,12 +528,6 @@ bool verify(const Options &options) { // Compute reference output // - // Block scaling tensors shapes based CTA Block (TileShape) and GEMM Problem shape - auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto 
blockscale_m = ceil_div(options.m, get<0>(TileShape{})); - auto blockscale_n = ceil_div(options.n, get<1>(TileShape{})); - auto blockscale_k = ceil_div(options.k, get<2>(TileShape{})); - // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(tensor_A.host_data(), cute::make_layout( @@ -557,28 +560,18 @@ bool verify(const Options &options) { ) ); - auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(), - cute::make_layout( - cute::make_shape(blockscale_m, blockscale_k, options.l), - cute::make_stride(1, blockscale_m, blockscale_m * blockscale_k) - ) - ); - auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), - cute::make_layout( - cute::make_shape(blockscale_n, blockscale_k, options.l), - cute::make_stride(1, blockscale_n, blockscale_n * blockscale_k) - ) - ); + auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA); + auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB); using unused_t = decltype(D); - cutlass::reference::host::GettMainloopParams mainloop_params{ - A, B, // Operand Tensors - blockscale_A, blockscale_B // Blockwise scaling Tensors - }; + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; cutlass::reference::host::GettEpilogueParams< ElementScalar, diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index b7cdb00a..096e56a6 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ 
-75,11 +75,11 @@ #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" // Includes from examples directory #include "helper.h" #include "hopper_fp8_commandline.hpp" -#include "reference/host/gemm_with_groupwise_scaling.h" using namespace cute; @@ -120,55 +120,30 @@ using ElementAccumulator = float; // E using ElementBlockScale = float; // Element type for blockscaling during accumulation using ElementCompute = float; // Element type for epilogue computation -using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor -// Given TileShape = Shape<_128,_128,_128>: -// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor) -// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling -// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling -// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling -template -struct GroupScaleConfig { - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size - using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int 
ScaleGranularityK = 128; - static constexpr int ScaleGranularityM = ScaleGranularityM_; - static constexpr int ScaleGranularityN = ScaleGranularityN_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; - static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; +constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; +constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; - static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, - "FP8 scaling granularity must evenly divide tile shape along M."); - static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile, - "FP8 scaling granularity must evenly divide tile shape along N."); +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; - using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>; -}; -using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>; -using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>; -using GroupScale2D1DConfig = 
GroupScaleConfig(TileShape_{}), 1>; -using GroupScale2D2DConfig = GroupScaleConfig(TileShape_{}), size<1>(TileShape_{})>; - -template -struct GroupScaleGemm { - using ArchTag = typename ScheduleConfig::ArchTag; - using OperatorClass = typename ScheduleConfig::OperatorClass; - using TileShape = typename ScheduleConfig::TileShape; - using ClusterShape = typename ScheduleConfig::ClusterShape; - using KernelSchedule = typename ScheduleConfig::KernelSchedule; - using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; - using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; - using FusionOperation = typename ScheduleConfig::FusionOperation; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, @@ -179,10 +154,10 @@ struct GroupScaleGemm { FusionOperation >::CollectiveOp; - using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< @@ -191,38 +166,26 @@ struct GroupScaleGemm { KernelSchedule >::CollectiveOp; - using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloopWithGroupWiseScaling, - CollectiveEpilogue + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler >; - using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloopWithGroupWiseScaling, - CollectiveEpilogue, - cutlass::gemm::StreamKScheduler - >; - - using 
GemmDefault = cutlass::gemm::device::GemmUniversalAdapter; - using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter; -}; - -using GroupScale1D1DGemm = GroupScaleGemm; -using GroupScale1D2DGemm = GroupScaleGemm; -using GroupScale2D1DGemm = GroupScaleGemm; -using GroupScale2D2DGemm = GroupScaleGemm; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; // Extract information from Gemm kernel. -using EpilogueOutputOp = typename GroupScale1D1DGemm::GemmDefault::EpilogueOutputOp; +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; using ElementScalar = typename EpilogueOutputOp::ElementScalar; using ElementAmax = typename EpilogueOutputOp::ElementAmax; using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; -using StrideA = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideA; -using StrideB = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideB; -using StrideC = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideC; -using StrideD = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideD; +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; using StrideAux = StrideD; constexpr bool IsDFp8 = @@ -242,20 +205,23 @@ StrideB stride_B; StrideC stride_C; StrideD stride_D; StrideAux stride_aux; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; uint64_t seed; +using LayoutScalar = cutlass::layout::PackedVectorLayout; + cutlass::HostTensor tensor_A; cutlass::HostTensor tensor_B; cutlass::HostTensor tensor_C; cutlass::HostTensor tensor_D; uint32_t mma_promotion_interval; -cutlass::HostTensor blockscale_tensor_A; -cutlass::HostTensor blockscale_tensor_B; +cutlass::HostTensor blockscale_tensor_A; +cutlass::HostTensor blockscale_tensor_B; cutlass::HostTensor tensor_ref_D; cutlass::HostTensor tensor_aux; cutlass::HostTensor tensor_ref_aux; -using LayoutScalar = 
cutlass::layout::PackedVectorLayout; cutlass::HostTensor scalar_alpha; cutlass::HostTensor scalar_beta; cutlass::HostTensor scale_A; @@ -392,32 +358,25 @@ bool initialize_scale_tensor( } /// Initialize operands to be used in the GEMM and reference GEMM -template void initialize(const Options &options) { - using TileShape = typename GroupScaleConfig::TileShape; - const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM; - const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN; - assert(options.m % ScaleGranularityM == 0); assert(options.n % ScaleGranularityN == 0); - // Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape - auto groupscale_m = ceil_div(options.m, ScaleGranularityM); - auto groupscale_n = ceil_div(options.n, ScaleGranularityN); - auto blockscale_k = ceil_div(options.k, cute::get<2>(TileShape{})); - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); stride_aux = stride_D; + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); - auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k); - auto groupscale_b_coord = cutlass::make_Coord(groupscale_n * options.l, blockscale_k); + auto groupscale_a_coord = 
cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto groupscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); tensor_A.resize(a_coord); tensor_B.resize(b_coord); @@ -522,7 +481,9 @@ GemmArguments args_from_options(const Options &options) stride_B, mma_promotion_interval, blockscale_tensor_A.device_data(), - blockscale_tensor_B.device_data() + layout_SFA, + blockscale_tensor_B.device_data(), + layout_SFB }, { {}, // epilogue.thread @@ -572,19 +533,10 @@ GemmArguments args_from_options(const Options &options) } /// Don't know why the compiler does not like verify() being templated... -bool verify(const Options &options, const int ScaleMsPerTile, const int ScaleNsPerTile) { +bool verify(const Options &options) { // // Compute reference output // - const int ScaleGranularityM = get<0>(TileShape_{}) / ScaleMsPerTile; - const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile; - - // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape - auto blockscale_m = ceil_div(options.m, get<0>(TileShape_{})); - auto blockscale_n = ceil_div(options.n, get<1>(TileShape_{})); - auto blockscale_k = ceil_div(options.k, get<2>(TileShape_{})); - auto groupscale_m = ceil_div(options.m, ScaleGranularityM); - auto groupscale_n = ceil_div(options.n, ScaleGranularityN); // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(tensor_A.host_data(), @@ -618,28 +570,18 @@ bool verify(const Options &options, const int ScaleMsPerTile ) ); - auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(), - cute::make_layout( - cute::make_shape(groupscale_m, blockscale_k, options.l), - cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k) - ) - ); - auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), - cute::make_layout( - cute::make_shape(groupscale_n, blockscale_k, options.l), - cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k) - ) - ); + 
auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA); + auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB); using unused_t = decltype(D); - cutlass::reference::host::GettMainloopParams mainloop_params{ - A, B, // Operand Tensors - blockscale_A, blockscale_B // Groupwise scaling Tensors - }; + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; cutlass::reference::host::GettEpilogueParams< ElementScalar, @@ -713,14 +655,7 @@ bool verify(const Options &options, const int ScaleMsPerTile } /// Execute a given example GEMM computation -template -int run(Options &options) -{ - using TileShape = typename GroupScaleConfig::TileShape; - const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM; - const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN; - const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile; - const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile; +int run(Options &options) { bool skip = false; std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; @@ -747,7 +682,7 @@ int run(Options &options) if (!skip) std::cout << " Running... " << std::endl; else return -1; - initialize(options); + initialize(options); // Instantiate CUTLASS kernel depending on templates Gemm gemm; @@ -773,7 +708,7 @@ int run(Options &options) // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; if (options.verify) { - result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile); + result.passed = verify(options); std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; } @@ -860,28 +795,7 @@ int main(int argc, char const **args) { #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) bool passed = true; - std::cout << "Basic split-K GEMM kernel" << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - - std::cout << std::endl; - - std::cout << "StreamK GEMM kernel" << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - passed &= run(options); - std::cout << std::endl; - + passed = run(options); if (!passed) return -1; #endif diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_blockwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_blockwise_scaling.h deleted file mode 100644 index 8904060c..00000000 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_blockwise_scaling.h +++ /dev/null @@ -1,504 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GETT in host-side code. 
-*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/gemm.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/relatively_equal.h" -#include -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -template -struct ElementTraits { - using type = T; -}; - -template -struct ElementTraits().get()), void> > > { - using type = decltype(std::declval().get()); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorB_, // (N, K, L) - class TensorScaleA_, // (m, k, L) - class TensorScaleB_, // (n, k, L) - class TileShape_ -> -struct GettMainloopParams { - using ElementAccumulator = ElementAccumulator_; - using TensorA = TensorA_; - using TensorB = TensorB_; - using EngineA = typename TensorA::engine_type; - using LayoutA = typename TensorA::layout_type; - using EngineB = typename TensorB::engine_type; - using LayoutB = typename TensorB::layout_type; - - using TensorScaleA = TensorScaleA_; - using TensorScaleB = TensorScaleB_; - using TileShape = TileShape_; - using EngineScaleA = typename TensorScaleA::engine_type; - using EngineScaleB = typename TensorScaleB::engine_type; - - TensorA A{}; - TensorB B{}; - TensorScaleA ScaleA{}; - TensorScaleB ScaleB{}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template< - class ElementScalar_, - class ElementScalingFactor_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = TensorD_, // (M, 1) - class TensorAux_ = TensorD_, // (M, N, L) - class VectorAlpha_ = 
TensorD_, // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) - class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - class BiasBinaryOp_ = cutlass::plus, - bool PerColumnBias_ = false -> -struct GettEpilogueParams { - using ElementScalar = ElementScalar_; - using ElementScalingFactor = ElementScalingFactor_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using TensorC = TensorC_; - using TensorD = TensorD_; - using TensorAux = TensorAux_; - using VectorBias = VectorBias_; - using VectorAlpha = VectorAlpha_; - using VectorBeta = VectorBeta_; - using ActivationFunctor = ActivationFunctor_; - using BiasBinaryOp = BiasBinaryOp_; - - using EngineC = typename TensorC::engine_type; - using LayoutC = typename TensorC::layout_type; - using EngineD = typename TensorD::engine_type; - using LayoutD = typename TensorD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorC C{}; - TensorD D{}; - VectorBias Bias{}; - TensorAux Aux{}; - VectorAlpha Valpha{}; - VectorBeta Vbeta{}; - ElementCompute st = ElementCompute(1); - - ElementAccumulator* abs_max_D = nullptr; - ElementAccumulator* abs_max_Aux = nullptr; - - ElementScalingFactor scale_a = ElementScalingFactor(1); - ElementScalingFactor scale_b = ElementScalingFactor(1); - ElementScalingFactor scale_c = ElementScalingFactor(1); - ElementScalingFactor scale_d = ElementScalingFactor(1); - ElementScalingFactor scale_aux = ElementScalingFactor(1); - - bool beta_per_channel_scaling = false; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - General Tensor-Tensor contraction reference kernel with Blockwise scaling -template < - class MainloopParams, - class EpilogueParams -> -void Gett( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - - static int constexpr 
kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); - static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); - // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); - // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - gett_epilogue(epilogue_params, m, n, l, acc); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Mainloop -template -void gett_mainloop( - MainloopParams const& mainloop_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); - static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementA = typename ElementTraits::type; - using ElementB = typename ElementTraits::type; - using ElementBlockScaleA = typename ElementTraits::type; - using ElementBlockScaleB = typename ElementTraits::type; - - using RingOp = multiply_add; - RingOp fma_op; - - multiplies scale_op; - - static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; - - // Tempo accumulators to seperate blockwise accumulation - typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; - - // Zero out accumulators - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = 
ElementAccumulator(0); // RingOp::AdditionIdentity - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - - int64_t block_m = m / kBlockM; - int64_t block_n = n / kBlockN; - cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, l); - cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l); - - // Compute on this k-block - for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { - - // Load Blockwise scaling factor from blockscale Tensors for A and B - int64_t block_k = k / kBlockK; - ElementBlockScaleA scale_a = blockscale_A[block_k]; - ElementBlockScaleB scale_b = blockscale_B[block_k]; - - // Load A - ElementAccumulator a_frag[kBlockM]; - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); - } else { - a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Load B - ElementAccumulator b_frag[kBlockN]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. 
- b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); - } else { - b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // do compute - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); - } - } - - // Apply Blockwise-scaling at kBlockK boundary - // (a) Apply block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary - // (b) Zero-out partial temporary (acc_temp), - // (c) Update permanent (accu) - if ((k+1) % kBlockK == 0) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a * scale_b; - acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - } - - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Epilogue -template -void gett_epilogue( - EpilogueParams const& epilogue_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); - static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::TensorC::value_type; - using ElementD = typename EpilogueParams::TensorD::value_type; - using ElementAux = typename EpilogueParams::TensorAux::value_type; - using ElementBias = typename EpilogueParams::VectorBias::value_type; - using ElementScalar = typename EpilogueParams::ElementScalar; - using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; - using ActivationFunctor = typename EpilogueParams::ActivationFunctor; - using BiasBinaryOp = typename 
EpilogueParams::BiasBinaryOp; - - constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr bool IsScalingAndAmaxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsScalingAndAmaxAuxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsReLUAuxNeeded = - (cute::is_same_v> or - cute::is_same_v>) and - cute::is_same_v; - constexpr bool IsClamp = - cute::is_same_v>; - - constexpr bool IsBackpropFusion = - cute::is_same_v> or - cute::is_same_v>; - - // Input related converter - NumericConverter accumulator_converter; - NumericConverter source_converter; - NumericConverter bias_converter; - [[maybe_unused]] NumericConverter aux_source_converter; - - // Scale related converter - NumericConverter scale_converter; - NumericConverter scaling_factor_converter; - - // Abs max converter - [[maybe_unused]] NumericConverter abs_max_output_converter; - - // Output related converter - NumericConverter destination_converter; - [[maybe_unused]] NumericConverter aux_destination_converter; - NumericConverter dBias_converter; - - // Epilogue operations - multiply_add epilogue_fma; - multiplies mul; - plus add; - - // Activation operation - ActivationFunctor activation; - - // Bias binary operation - BiasBinaryOp bias_op; - - // Do conversion - ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); - ElementCompute converted_beta = scale_converter(epilogue_params.beta); - ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); - ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); - ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); - ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); - ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); - - // Init local var - [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); - 
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); - - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - converted_beta = mul(converted_beta, converted_scale_c); - - ElementCompute inter_accum[kBlockM][kBlockN]; - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - ElementCompute local_dBias = ElementCompute(0); - - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - // Convert every type to ElementCompute first, do compute, convert to output type, write it out - ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // per-row alpha - if (raw_pointer_cast(epilogue_params.Valpha.data())) { - converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); - } - ElementCompute output = mul(converted_alpha, converted_acc); - - if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { - ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? 
n + n_b : m + m_b)); - output = bias_op(output, converted_bias); - } - - if (raw_pointer_cast(epilogue_params.C.data())) { - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // per-row beta - if (epilogue_params.Vbeta.data()) { - converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); - } - output = epilogue_fma(converted_beta, converted_src, output); - } - - if constexpr (IsBackpropFusion) { - ElementAux aux_input = ElementAux(0); - if (raw_pointer_cast(epilogue_params.Aux.data())) { - aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); - } - - output = activation(output, aux_source_converter(aux_input)); - local_dBias = add(local_dBias, output); - } - else { - if (raw_pointer_cast(epilogue_params.Aux.data())) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); - } - - if constexpr (IsReLUAuxNeeded) { - epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); - } else { - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); - } - } - - if constexpr (IsClamp) { // Treat Clamp as ReLU - output = activation(output, {0, std::numeric_limits::max()}); - } - else { - output = activation(output); - } - } - - if constexpr (IsScalingAndAmaxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_output = amax_op(local_abs_max_output, output); - output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); - } - - inter_accum[m_b][n_b] = ElementCompute(output); - } - } // n_b - - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { - if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { - ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); - local_dBias = add(local_dBias, converted_dBias); - epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); - } - } - } // m_b - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); - } - } - } - -#if defined(_OPENMP) - #pragma omp critical(Abs_Max_Data_Update) -#endif - { - if constexpr (IsScalingAndAmaxOutputNeeded) { - if (epilogue_params.abs_max_D) { - *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); - } - } - - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - if (epilogue_params.abs_max_Aux) { - *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM - General 
Matrix-Matrix contraction without conjugation options -template < - class MainloopParams, - class EpilogueParams -> -void Gemm3x( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - using namespace cute; - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " - "with Batchmode are supported"); - // Lower the Matrix-Multiplication with Blockwise scaling (Gemm3x) to a Tensor Contraction (Gett). - Gett(mainloop_params, epilogue_params); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h deleted file mode 100644 index 0bf90a41..00000000 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ /dev/null @@ -1,518 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GETT in host-side code. 
-*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/gemm.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/relatively_equal.h" -#include -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -template -struct ElementTraits { - using type = T; -}; - -template -struct ElementTraits().get()), void> > > { - using type = decltype(std::declval().get()); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorB_, // (N, K, L) - class TensorScaleA_, // (m, k, L) - class TensorScaleB_, // (n, k, L) - class TileShape_ -> -struct GettMainloopParams { - using ElementAccumulator = ElementAccumulator_; - using TensorA = TensorA_; - using TensorB = TensorB_; - using EngineA = typename TensorA::engine_type; - using LayoutA = typename TensorA::layout_type; - using EngineB = typename TensorB::engine_type; - using LayoutB = typename TensorB::layout_type; - - using TensorScaleA = TensorScaleA_; - using TensorScaleB = TensorScaleB_; - using TileShape = TileShape_; - using EngineScaleA = typename TensorScaleA::engine_type; - using EngineScaleB = typename TensorScaleB::engine_type; - - TensorA A{}; - TensorB B{}; - TensorScaleA ScaleA{}; - TensorScaleB ScaleB{}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template< - class ElementScalar_, - class ElementScalingFactor_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = TensorD_, // (M, 1) - class TensorAux_ = TensorD_, // (M, N, L) - class VectorAlpha_ = 
TensorD_, // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) - class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - class BiasBinaryOp_ = cutlass::plus, - bool PerColumnBias_ = false -> -struct GettEpilogueParams { - using ElementScalar = ElementScalar_; - using ElementScalingFactor = ElementScalingFactor_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using TensorC = TensorC_; - using TensorD = TensorD_; - using TensorAux = TensorAux_; - using VectorBias = VectorBias_; - using VectorAlpha = VectorAlpha_; - using VectorBeta = VectorBeta_; - using ActivationFunctor = ActivationFunctor_; - using BiasBinaryOp = BiasBinaryOp_; - - using EngineC = typename TensorC::engine_type; - using LayoutC = typename TensorC::layout_type; - using EngineD = typename TensorD::engine_type; - using LayoutD = typename TensorD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorC C{}; - TensorD D{}; - VectorBias Bias{}; - TensorAux Aux{}; - VectorAlpha Valpha{}; - VectorBeta Vbeta{}; - ElementCompute st = ElementCompute(1); - - ElementAccumulator* abs_max_D = nullptr; - ElementAccumulator* abs_max_Aux = nullptr; - - ElementScalingFactor scale_a = ElementScalingFactor(1); - ElementScalingFactor scale_b = ElementScalingFactor(1); - ElementScalingFactor scale_c = ElementScalingFactor(1); - ElementScalingFactor scale_d = ElementScalingFactor(1); - ElementScalingFactor scale_aux = ElementScalingFactor(1); - - bool beta_per_channel_scaling = false; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling -template < - class MainloopParams, - class EpilogueParams -> -void Gett( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - - static int constexpr 
kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); - static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); - // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); - // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - gett_epilogue(epilogue_params, m, n, l, acc); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Mainloop -template -void gett_mainloop( - MainloopParams const& mainloop_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); - static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementA = typename ElementTraits::type; - using ElementB = typename ElementTraits::type; - using ElementBlockScaleA = typename ElementTraits::type; - using ElementBlockScaleB = typename ElementTraits::type; - - using RingOp = multiply_add; - RingOp fma_op; - - multiplies scale_op; - - static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; - - // Tempo accumulators to seperate blockwise accumulation - typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; - - // Zero out accumulators - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = 
ElementAccumulator(0); // RingOp::AdditionIdentity - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - - const int M = cute::size<0>(mainloop_params.A.layout()); - const int N = cute::size<0>(mainloop_params.B.layout()); - const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA); - const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB); - assert(ScaleGranularityM && M % ScaleGranularityM == 0 - && "ScaleGranularityM must divide M"); - assert(ScaleGranularityN && N % ScaleGranularityN == 0 - && "ScaleGranularityN must divide N"); - - cute::Tensor blockscale_A = domain_offset( - make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l)); - cute::Tensor blockscale_B = domain_offset( - make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l)); - - // Compute on this k-block - for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { - - // Load Blockwise scaling factor from blockscale Tensors for B - int64_t block_k = k / kBlockK; - cute::Tensor scale_a = blockscale_A(_, block_k); - cute::Tensor scale_b = blockscale_B(_, block_k); - - // Load A - ElementAccumulator a_frag[kBlockM]; - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); - } else { - a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Load B - ElementAccumulator b_frag[kBlockN]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. 
- b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); - } else { - b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - int m_size = std::min(static_cast(kBlockM), cute::size<0>(mainloop_params.A.layout()) - m); - int n_size = std::min(static_cast(kBlockN), cute::size<0>(mainloop_params.B.layout()) - n); - - // do compute - for (int m_b = 0; m_b < m_size; ++m_b) { - for (int n_b = 0; n_b < n_size; ++n_b) { - acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); - } - } - - // Apply Groupwise-scaling at kBlockK boundary - // (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary - // (b) Zero-out partial temporary (acc_temp), - // (c) Update permanent (accu) - if ((k+1) % kBlockK == 0) { - for (int m_b = 0; m_b < m_size; ++m_b) { - auto scale_a_m_b = scale_a[m_b / ScaleGranularityM]; - for (int n_b = 0; n_b < n_size; ++n_b) { - auto scale_b_n_b = scale_b[n_b / ScaleGranularityN]; - ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b; - acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - } - - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Epilogue -template -void gett_epilogue( - EpilogueParams const& epilogue_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); - static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::TensorC::value_type; - using ElementD = typename EpilogueParams::TensorD::value_type; - using ElementAux = typename EpilogueParams::TensorAux::value_type; - using ElementBias 
= typename EpilogueParams::VectorBias::value_type; - using ElementScalar = typename EpilogueParams::ElementScalar; - using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; - using ActivationFunctor = typename EpilogueParams::ActivationFunctor; - using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; - - constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr bool IsScalingAndAmaxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsScalingAndAmaxAuxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsReLUAuxNeeded = - (cute::is_same_v> or - cute::is_same_v>) and - cute::is_same_v; - constexpr bool IsClamp = - cute::is_same_v>; - - constexpr bool IsBackpropFusion = - cute::is_same_v> or - cute::is_same_v>; - - // Input related converter - NumericConverter accumulator_converter; - NumericConverter source_converter; - NumericConverter bias_converter; - [[maybe_unused]] NumericConverter aux_source_converter; - - // Scale related converter - NumericConverter scale_converter; - NumericConverter scaling_factor_converter; - - // Abs max converter - [[maybe_unused]] NumericConverter abs_max_output_converter; - - // Output related converter - NumericConverter destination_converter; - [[maybe_unused]] NumericConverter aux_destination_converter; - NumericConverter dBias_converter; - - // Epilogue operations - multiply_add epilogue_fma; - multiplies mul; - plus add; - - // Activation operation - ActivationFunctor activation; - - // Bias binary operation - BiasBinaryOp bias_op; - - // Do conversion - ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); - ElementCompute converted_beta = scale_converter(epilogue_params.beta); - ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); - ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); - ElementCompute converted_scale_c = 
scaling_factor_converter(epilogue_params.scale_c); - ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); - ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); - - // Init local var - [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); - [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); - - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - converted_beta = mul(converted_beta, converted_scale_c); - - ElementCompute inter_accum[kBlockM][kBlockN]; - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - ElementCompute local_dBias = ElementCompute(0); - - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - // Convert every type to ElementCompute first, do compute, convert to output type, write it out - ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // per-row alpha - if (raw_pointer_cast(epilogue_params.Valpha.data())) { - converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); - } - ElementCompute output = mul(converted_alpha, converted_acc); - - if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { - ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? 
n + n_b : m + m_b)); - output = bias_op(output, converted_bias); - } - - if (raw_pointer_cast(epilogue_params.C.data())) { - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // per-row beta - if (epilogue_params.Vbeta.data()) { - converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); - } - output = epilogue_fma(converted_beta, converted_src, output); - } - - if constexpr (IsBackpropFusion) { - ElementAux aux_input = ElementAux(0); - if (raw_pointer_cast(epilogue_params.Aux.data())) { - aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); - } - - output = activation(output, aux_source_converter(aux_input)); - local_dBias = add(local_dBias, output); - } - else { - if (raw_pointer_cast(epilogue_params.Aux.data())) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); - } - - if constexpr (IsReLUAuxNeeded) { - epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); - } else { - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); - } - } - - if constexpr (IsClamp) { // Treat Clamp as ReLU - output = activation(output, {0, std::numeric_limits::max()}); - } - else { - output = activation(output); - } - } - - if constexpr (IsScalingAndAmaxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_output = amax_op(local_abs_max_output, output); - output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); - } - - inter_accum[m_b][n_b] = ElementCompute(output); - } - } // n_b - - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { - if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { - ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); - local_dBias = add(local_dBias, converted_dBias); - epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); - } - } - } // m_b - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); - } - } - } - -#if defined(_OPENMP) - #pragma omp critical(Abs_Max_Data_Update) -#endif - { - if constexpr (IsScalingAndAmaxOutputNeeded) { - if (epilogue_params.abs_max_D) { - *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); - } - } - - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - if (epilogue_params.abs_max_Aux) { - *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM - General 
Matrix-Matrix contraction without conjugation options -template < - class MainloopParams, - class EpilogueParams -> -void Gemm3x( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - using namespace cute; - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " - "with Batchmode are supported"); - // Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett). - Gett(mainloop_params, epilogue_params); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu index d20bad58..d14360de 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu @@ -87,11 +87,11 @@ #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" // 
Includes from examples directory #include "helper.h" #include "hopper_fp8_commandline.hpp" -#include "reference/host/gemm_with_groupwise_scaling.h" using namespace cute; @@ -128,54 +128,29 @@ using ElementAccumulator = float; // E using ElementBlockScale = float; // Element type for blockscaling during accumulation using ElementCompute = float; // Element type for epilogue computation -using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor -// Given TileShape = Shape<_128,_128,_128>: -// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor) -// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling -// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling -// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling -template -struct GroupScaleConfig { - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size - using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; - static constexpr int ScaleGranularityM = ScaleGranularityM_; - static constexpr int ScaleGranularityN = ScaleGranularityN_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / 
ScaleGranularityM; - static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; +constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; +constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; - static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, - "FP8 scaling granularity must evenly divide tile shape along M."); - static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile, - "FP8 scaling granularity must evenly divide tile shape along N."); +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; - using FusionOperation = cutlass::epilogue::fusion::LinearCombination; -}; +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand -using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>; -using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>; -using GroupScale2D1DConfig = GroupScaleConfig(TileShape_{}), 1>; -using GroupScale2D2DConfig = GroupScaleConfig(TileShape_{}), size<1>(TileShape_{})>; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::LinearCombination; -template -struct GroupScaleGemm { - using ArchTag = typename ScheduleConfig::ArchTag; - using OperatorClass = typename ScheduleConfig::OperatorClass; - using TileShape = typename ScheduleConfig::TileShape; - using ClusterShape = typename 
ScheduleConfig::ClusterShape; - using KernelSchedule = typename ScheduleConfig::KernelSchedule; - using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; - using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; - using FusionOperation = typename ScheduleConfig::FusionOperation; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, @@ -186,10 +161,10 @@ struct GroupScaleGemm { FusionOperation >::CollectiveOp; - using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< +using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - ElementA, LayoutA *, AlignmentA, - ElementB, LayoutB *, AlignmentB, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< @@ -198,29 +173,23 @@ struct GroupScaleGemm { KernelSchedule >::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloopWithGroupWiseScaling, - CollectiveEpilogue +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopWithGroupWiseScaling, + CollectiveEpilogue >; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -using GroupScale1D1DGemm = GroupScaleGemm; -using GroupScale1D2DGemm = GroupScaleGemm; -using GroupScale2D1DGemm = GroupScaleGemm; -using GroupScale2D2DGemm = GroupScaleGemm; // Extract information from Gemm kernel. 
-using EpilogueOutputOp = typename GroupScale1D1DGemm::Gemm::EpilogueOutputOp; +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; using ElementScalar = typename EpilogueOutputOp::ElementScalar; -using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; -using StrideA = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideA; -using StrideB = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideB; -using StrideC = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideC; -using StrideD = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideD; +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; static_assert(cute::is_same_v, "ElementAccumulator and ElementBlockScale should be same datatype"); @@ -240,6 +209,8 @@ std::vector stride_A_host; std::vector stride_B_host; std::vector stride_C_host; std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; std::vector alpha_host; std::vector beta_host; @@ -265,6 +236,8 @@ cutlass::DeviceAllocation stride_A; cutlass::DeviceAllocation stride_B; cutlass::DeviceAllocation stride_C; cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; cutlass::DeviceAllocation alpha_device; cutlass::DeviceAllocation beta_device; @@ -343,10 +316,6 @@ bool initialize_block( template void allocate(const OptionType &options) { - using TileShape = typename OptionType::GroupScaleConfig::TileShape; - const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile; - const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile; - int64_t total_elements_A = 0; int64_t total_elements_B = 0; int64_t total_elements_C = 0; @@ -372,10 +341,8 @@ void allocate(const OptionType &options) { auto N = get<1>(problem); auto K 
= get<2>(problem); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{}))); - auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access. - auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access. - auto blockscale_k = cute::get<2>(blockscale_shape); + auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B); @@ -388,8 +355,8 @@ void allocate(const OptionType &options) { int64_t elements_B = K * N; int64_t elements_C = M * N; int64_t elements_D = M * N; - int64_t elements_blockscale_A = groupscale_m * blockscale_k; - int64_t elements_blockscale_B = groupscale_n * blockscale_k; + int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA)); + int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB)); total_elements_A += elements_A; total_elements_B += elements_B; @@ -402,6 +369,8 @@ void allocate(const OptionType &options) { stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(group_layout_SFA); + layout_SFB_host.push_back(group_layout_SFB); } @@ -477,6 +446,12 @@ void initialize(const OptionType &options) { stride_D.reset(options.groups); stride_D.copy_from_host(stride_D_host.data()); + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + alpha_device.reset(options.groups); 
alpha_device.copy_from_host(ptr_alpha_host.data()); beta_device.reset(options.groups); @@ -500,14 +475,14 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha // Change device_id to another value if you are running on a machine with multiple GPUs and wish // to use a GPU other than that with device ID 0. int device_id = 0; - cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); GemmArguments arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), - ptr_blockscale_A.get(), - ptr_blockscale_B.get() + ptr_blockscale_A.get(), layout_SFA.get(), + ptr_blockscale_B.get(), layout_SFB.get() }, { {}, // epilogue.thread @@ -577,12 +552,6 @@ bool verify(const OptionType &options) { // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape auto [m, n, k] = options.problem_sizes_host.at(group_idx); auto gemm_problem_shape = cute::make_shape(m, n, k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{}))); - auto blockscale_m = cute::get<0>(blockscale_shape); - auto blockscale_n = cute::get<1>(blockscale_shape); - auto blockscale_k = cute::get<2>(blockscale_shape); - auto groupscale_m = blockscale_m * OptionType::GroupScaleConfig::ScaleMsPerTile; - auto groupscale_n = blockscale_n * OptionType::GroupScaleConfig::ScaleNsPerTile; // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx), @@ -610,32 +579,20 @@ bool verify(const OptionType &options) { ) ); - auto blockscale_A = 
cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx), - cute::make_layout( - cute::make_shape(groupscale_m, blockscale_k, 1), - cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k) - ) - ); - auto blockscale_B = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx), - cute::make_layout( - cute::make_shape(groupscale_n, blockscale_k, 1), - cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k) - ) - ); + auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx), + layout_SFA_host.at(group_idx)); + auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx), + layout_SFB_host.at(group_idx)); using unused_t = decltype(D); - cutlass::reference::host::GettMainloopParams< + cutlass::reference::host::GettBlockScalingMainloopParams< ElementAccumulator, - decltype(A), + decltype(A), + decltype(SFA), decltype(B), - decltype(blockscale_A), - decltype(blockscale_B), - TileShape_ - > mainloop_params{ - A, B, // Operand Tensors - blockscale_A, blockscale_B // Groupwise scaling Tensors - }; + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; cutlass::reference::host::GettEpilogueParams< ElementScalar, @@ -647,8 +604,7 @@ bool verify(const OptionType &options) { unused_t, // bias unused_t, // Aux unused_t, // valpha - unused_t, // vbeta - ActivationFunctor + unused_t // vbeta > epilogue_params; epilogue_params.C = C; @@ -679,15 +635,9 @@ bool verify(const OptionType &options) { } /// Execute a given example GEMM computation -template +template int run(OptionType &options, bool host_problem_shapes_available = true) { - using TileShape = typename OptionType::GroupScaleConfig::TileShape; - const int ScaleGranularityM = OptionType::GroupScaleConfig::ScaleGranularityM; - const int ScaleGranularityN = OptionType::GroupScaleConfig::ScaleGranularityN; - const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile; - const 
int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile; - allocate(options); initialize(options); @@ -797,18 +747,12 @@ int main(int argc, char const **args) { // Parse options // - Options options_1d1d; - Options options_1d2d; - Options options_2d1d; - Options options_2d2d; + Options options; - options_1d1d.parse(argc, args); - options_1d2d.parse(argc, args); - options_2d1d.parse(argc, args); - options_2d2d.parse(argc, args); + options.parse(argc, args); - if (options_1d1d.help) { - options_1d1d.print_usage(std::cout) << std::endl; + if (options.help) { + options.print_usage(std::cout) << std::endl; return 0; } @@ -816,22 +760,10 @@ int main(int argc, char const **args) { // Evaluate CUTLASS kernels // - auto run_tests = [&] (bool host_problem_shapes_available = true) { - std::cout << "Grouped GEMM kernel with 1D1D group scale" << std::endl; - run(options_1d1d, host_problem_shapes_available); - std::cout << "Grouped GEMM kernel with 1D2D group scale" << std::endl; - run(options_1d2d, host_problem_shapes_available); - std::cout << "Grouped GEMM kernel with 2D1D group scale" << std::endl; - run(options_2d1d, host_problem_shapes_available); - std::cout << "Grouped GEMM kernel with 2D2D group scale" << std::endl; - run(options_2d2d, host_problem_shapes_available); - std::cout << std::endl; - }; - std::cout << "Running tests with host problem shapes:" << std::endl; - run_tests(true); + run(options, true); std::cout << "Running tests without host problem shapes:" << std::endl; - run_tests(false); + run(options, false); #endif diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu new file mode 100644 index 00000000..2ea42bbf --- /dev/null +++ 
b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu @@ -0,0 +1,781 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Grouped scale Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + This example demonstrates a grouped scaled FP8 Grouped GEMM using the new CUTLASS 3.0 + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA + descriptors to move between groups/problem_count (represented by groups). + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + 4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. + 5. This example is tuned specifically for the sparse groups case, where the number of active groups (groups + with non-zero problem count) is much smaller than the total number of groups.
+ Examples: + $ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups \ + --m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \ + --raster=h --swizzle=2 --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" + +// Includes from examples directory +#include "helper.h" +#include "hopper_fp8_commandline.hpp" + +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel 
configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementBlockScale = float; // Element type for blockscaling during accumulation +using ElementCompute = float; // Element type for epilogue computation + +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
+using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + +static constexpr int ScaleGranularityM = 1; +static constexpr int ScaleGranularityN = 128; +static constexpr int ScaleGranularityK = 128; +static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; +static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::LinearCombination; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule, + FusionOperation +>::CollectiveOp; + +using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule +>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopWithGroupWiseScaling, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. 
+using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + +/// Initialization + +cutlass::DeviceAllocation problem_sizes; + +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_blockscale_A; +std::vector offset_blockscale_B; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; + +std::vector alpha_host; +std::vector beta_host; + +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation blockscale_block_A; +cutlass::DeviceAllocation blockscale_block_B; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; +cutlass::DeviceAllocation ptr_blockscale_A; +cutlass::DeviceAllocation ptr_blockscale_B; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; + +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && 
defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams>::RasterOrderOptions; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + double gbps; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + double gbps = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), gbps(gbps), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023, + ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) { + + double _scope_max, _scope_min; + int bits_input = cutlass::sizeof_bits::value; + if (bits_input == 1) { + _scope_max = 2; + _scope_min = 0; + } else if (bits_input <= 8) { + _scope_max = 2; + _scope_min = -2; + } else if (bits_input == 16) { + _scope_max = 5; + _scope_min = -5; + } else { + _scope_max = 8; + _scope_min = -8; + } + if constexpr (!std::is_same_v) { + _scope_max = scope_max; + } + if constexpr (!std::is_same_v) { + _scope_min = scope_min; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) _scope_max, 
(Element) _scope_min, 0); + + return true; +} + +/// Allocates device-side data +template +void allocate(const OptionType &options) { + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_elements_blockscale_A = 0; + int64_t total_elements_blockscale_B = 0; + + offset_A.clear(); + offset_B.clear(); + offset_C.clear(); + offset_D.clear(); + offset_blockscale_A.clear(); + offset_blockscale_B.clear(); + stride_A_host.clear(); + stride_B_host.clear(); + stride_C_host.clear(); + stride_D_host.clear(); + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_after_alignment_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_blockscale_A.push_back(total_elements_blockscale_A); + offset_blockscale_B.push_back(total_elements_blockscale_B); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA)); + int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB)); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_blockscale_A += elements_blockscale_A; + total_elements_blockscale_B += elements_blockscale_B; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + 
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + layout_SFA_host.push_back(group_layout_SFA); + layout_SFB_host.push_back(group_layout_SFB); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + blockscale_block_A.reset(total_elements_blockscale_A); + blockscale_block_B.reset(total_elements_blockscale_B); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +template +void initialize(const OptionType &options) { + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_after_alignment_host.data()); + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + std::vector ptr_blockscale_A_host(options.groups); + std::vector ptr_blockscale_B_host(options.groups); + + alpha_host.clear(); + beta_host.clear(); + + for (int i = 0; i < options.groups; i++) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i); + ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? 
static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_blockscale_A.reset(options.groups); + ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data()); + + ptr_blockscale_B.reset(options.groups); + ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2022); + initialize_block(block_B, seed + 2023); + initialize_block(block_C, seed + 2024); + initialize_block(blockscale_block_A, seed + 2025, -1, 1); + initialize_block(blockscale_block_B, seed + 2026, -1, 1); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true) +{ + // Change device_id to another value if you are running on a machine with multiple 
GPUs and wish + // to use a GPU other than that with device ID 0. + int device_id = 0; + cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + + GemmArguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_after_alignment_host.data() : (decltype(options.problem_sizes_after_alignment_host.data())) nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_blockscale_A.get(), layout_SFA.get(), + ptr_blockscale_B.get(), layout_SFB.get() + }, + { + {}, // epilogue.thread + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + kernel_hw_info + }; + + auto &fusion_args = arguments.epilogue.thread; + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. 
+ fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + + return arguments; +} + +template +bool verify(const OptionType &options) { + + // + // Compute reference output + // + + std::vector block_A_host(block_A.size()); + std::vector block_B_host(block_B.size()); + std::vector block_C_host(block_C.size()); + std::vector block_D_host_kernel(block_D.size()); + std::vector block_D_host_ref(block_D.size()); + std::vector blockscale_block_A_host(blockscale_block_A.size()); + std::vector blockscale_block_B_host(blockscale_block_B.size()); + + block_A.copy_to_host(block_A_host.data()); + block_B.copy_to_host(block_B_host.data()); + block_C.copy_to_host(block_C_host.data()); + block_D.copy_to_host(block_D_host_kernel.data()); + blockscale_block_A.copy_to_host(blockscale_block_A_host.data()); + blockscale_block_B.copy_to_host(blockscale_block_B_host.data()); + + bool passed = true; + for (int group_idx = 0; group_idx < options.groups; group_idx++) { + // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape + auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx); + auto gemm_problem_shape = cute::make_shape(m, n, k); + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx), + cute::make_layout( + cute::make_shape(m, k, 1), + stride_A_host.at(group_idx) + ) + ); + auto B = cute::make_tensor(block_B_host.data() + 
offset_B.at(group_idx), + cute::make_layout( + cute::make_shape(n, k, 1), + stride_B_host.at(group_idx) + ) + ); + auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx), + cute::make_layout( + cute::make_shape(m, n, 1), + stride_C_host.at(group_idx) + ) + ); + auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx), + cute::make_layout( + cute::make_shape(m, n, 1), + stride_D_host.at(group_idx) + ) + ); + + auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx), + layout_SFA_host.at(group_idx)); + auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx), + layout_SFB_host.at(group_idx)); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha_host.at(group_idx); + epilogue_params.beta = beta_host.at(group_idx); + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + auto this_group_passed = std::equal( + // std::execution::par_unseq, + block_D_host_ref.data() + offset_D.at(group_idx), + block_D_host_ref.data() + offset_D.at(group_idx) + m * n, + block_D_host_kernel.data() + offset_D.at(group_idx) + ); + + passed &= this_group_passed; + +#if 0 + std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl; +#endif + + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(OptionType &options, bool 
host_problem_shapes_available = true) +{ + + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + result.gbps = options.template gbps(result.avg_runtime_ms / 1000.0); + + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; + std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; + std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. 
+ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(&current_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + + run(options, true); + + std::cout << "Running tests without host problem shapes:" << std::endl; + run(options, false); + +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt index f88b3167..09d506de 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt @@ -59,3 +59,26 @@ cutlass_example_add_executable( TEST_SMALL TEST_SMALL_LARGE_GROUP ) + +# MSVC will fail to compile this example with the following error: +# fatal error C1083: Cannot open source file: : No such file or directory 
[...\examples\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.vcxproj] +# This is a known issue and we are working on a fix. +if (NOT MSVC) + +cutlass_example_add_executable( + 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups + 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + ) + +endif() diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp index 3e425fe2..19497176 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -30,12 +30,11 @@ **************************************************************************************************/ // Command line options parsing -template +template struct Options { using RasterOrderOptions = _RasterOrderOptions; using ProblemShape = _ProblemShape; - using GroupScaleConfig = _GroupScaleConfig; bool help = false; @@ -43,6 +42,7 @@ struct Options { int iterations = 1000; int m = 1024, n = 512, k = 1024, groups = 10; std::string benchmark_path; + std::vector problem_sizes_after_alignment_host; std::vector problem_sizes_host; int const tma_alignment_bits = 128; int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; @@ -89,6 +89,7 @@ struct Options { // Decide how to initialize the problems if (!benchmark_path.empty()) { if (!benchmark_problems()) { + 
problem_sizes_after_alignment_host.clear(); problem_sizes_host.clear(); return; } @@ -105,8 +106,8 @@ struct Options { cmd.get_cmd_line_argument("n", cmd_line_n); cmd.get_cmd_line_argument("k", cmd_line_k); + problem_sizes_after_alignment_host.reserve(groups); problem_sizes_host.reserve(groups); - for (int i = groups; i > 0; i--) { int m = cmd_line_m; int n = cmd_line_n; @@ -120,6 +121,7 @@ struct Options { if (k < 1) { k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1); } + problem_sizes_after_alignment_host.push_back({m, n, k}); problem_sizes_host.push_back({m, n, k}); } } @@ -142,7 +144,7 @@ struct Options { break; } - cutlass::gemm::GemmCoord extent; + cutlass::gemm::GemmCoord extent_after_alignment, extent; std::vector tokens; cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); @@ -150,23 +152,81 @@ struct Options { for (int i = 0; i < int(tokens.size()); ++i) { int x = std::atoi(tokens.at(i).c_str()); + extent.at(i) = x; // round up if (x % alignment) { x += (alignment - (x % alignment)); } - extent.at(i) = x; + extent_after_alignment.at(i) = x; } - if (extent.product()) { - problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); - } + problem_sizes_after_alignment_host.push_back({extent_after_alignment.m(), extent_after_alignment.n(), extent_after_alignment.k()}); + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); } - groups = static_cast(problem_sizes_host.size()); + groups = static_cast(problem_sizes_after_alignment_host.size()); return true; } + /// Calculate memory bandwidth statistics + template + auto gbps(double runtime_s) const { + double total_read_bytes = 0; + double total_write_bytes = 0; + + // Calculate bytes read and written for each problem + for (int i = 0; i < groups; ++i) { + auto problem = problem_sizes_host.at(i); + auto M = cute::get<0>(problem); + auto N = cute::get<1>(problem); + auto K = cute::get<2>(problem); + + if (M > 0) { // Only count active problems + // Matrix A: M*K elements 
read + total_read_bytes += M * K * sizeof(ElementA); + + // Matrix B: K*N elements read + total_read_bytes += K * N * sizeof(ElementB); + + // Matrix C: M*N elements read (for beta operation) + total_read_bytes += M * N * sizeof(ElementC); + + // Block scales for A and B + auto blockscale_shape = cute::shape(cute::get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{}))); + auto blockscale_m = cute::get<0>(blockscale_shape); + auto blockscale_n = cute::get<1>(blockscale_shape); + auto blockscale_k = cute::get<2>(blockscale_shape); + auto groupscale_m = blockscale_m * ScaleMsPerTile; + auto groupscale_n = blockscale_n * ScaleNsPerTile; + + total_read_bytes += groupscale_m * blockscale_k * sizeof(ElementBlockScale); // A scales + total_read_bytes += groupscale_n * blockscale_k * sizeof(ElementBlockScale); // B scales + + // Matrix D: M*N elements written + total_write_bytes += M * N * sizeof(ElementD); + } + } + + return (total_read_bytes + total_write_bytes) / 1.0e9 / runtime_s; + } + + double bandwidth_util(double eff_bandwidth) const { + int memoryClockRate; + int memoryBusWidth; + cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, 0); + cudaDeviceGetAttribute(&memoryBusWidth, cudaDevAttrGlobalMemoryBusWidth , 0); + double bw = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6; + return eff_bandwidth / bw * 100.0; + } + /// Prints the usage statement. 
std::ostream & print_usage(std::ostream &out) const { diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h deleted file mode 100644 index 1a94af67..00000000 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ /dev/null @@ -1,520 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Reference implementation for GETT in host-side code. -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/gemm.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/relatively_equal.h" -#include -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::reference::host { - -template -struct ElementTraits { - using type = T; -}; - -template -struct ElementTraits().get()), void> > > { - using type = decltype(std::declval().get()); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class ElementAccumulator_, - class TensorA_, // (M, K, L) - class TensorB_, // (N, K, L) - class TensorScaleA_, // (m, k, L) - class TensorScaleB_, // (n, k, L) - class TileShape_ -> -struct GettMainloopParams { - using ElementAccumulator = ElementAccumulator_; - using TensorA = TensorA_; - using TensorB = TensorB_; - using EngineA = typename TensorA::engine_type; - using LayoutA = typename TensorA::layout_type; - using EngineB = typename TensorB::engine_type; - using LayoutB = typename 
TensorB::layout_type; - - using TensorScaleA = TensorScaleA_; - using TensorScaleB = TensorScaleB_; - using TileShape = TileShape_; - using EngineScaleA = typename TensorScaleA::engine_type; - using EngineScaleB = typename TensorScaleB::engine_type; - - TensorA A{}; - TensorB B{}; - TensorScaleA ScaleA{}; - TensorScaleB ScaleB{}; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template< - class ElementScalar_, - class ElementScalingFactor_, - class ElementAccumulator_, - class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = TensorD_, // (M, 1) - class TensorAux_ = TensorD_, // (M, N, L) - class VectorAlpha_ = TensorD_, // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) - class ActivationFunctor_ = cutlass::epilogue::thread::Identity, - class BiasBinaryOp_ = cutlass::plus, - bool PerColumnBias_ = false -> -struct GettEpilogueParams { - using ElementScalar = ElementScalar_; - using ElementScalingFactor = ElementScalingFactor_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using TensorC = TensorC_; - using TensorD = TensorD_; - using TensorAux = TensorAux_; - using VectorBias = VectorBias_; - using VectorAlpha = VectorAlpha_; - using VectorBeta = VectorBeta_; - using ActivationFunctor = ActivationFunctor_; - using BiasBinaryOp = BiasBinaryOp_; - - using EngineC = typename TensorC::engine_type; - using LayoutC = typename TensorC::layout_type; - using EngineD = typename TensorD::engine_type; - using LayoutD = typename TensorD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - - TensorC C{}; - TensorD D{}; - VectorBias Bias{}; - TensorAux Aux{}; - VectorAlpha Valpha{}; - VectorBeta Vbeta{}; - ElementCompute st = ElementCompute(1); - - ElementAccumulator* abs_max_D = nullptr; - ElementAccumulator* abs_max_Aux 
= nullptr; - - ElementScalingFactor scale_a = ElementScalingFactor(1); - ElementScalingFactor scale_b = ElementScalingFactor(1); - ElementScalingFactor scale_c = ElementScalingFactor(1); - ElementScalingFactor scale_d = ElementScalingFactor(1); - ElementScalingFactor scale_aux = ElementScalingFactor(1); - - bool beta_per_channel_scaling = false; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling -template < - class MainloopParams, - class EpilogueParams -> -void Gett( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - - static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); - static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); - // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); - // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); - -#if defined(_OPENMP) - #pragma omp parallel for collapse(3) -#endif - for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { - for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { - for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { - typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; - gett_mainloop(mainloop_params, m, n, l, acc); - gett_epilogue(epilogue_params, m, n, l, acc); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Mainloop -template -void gett_mainloop( - MainloopParams const& mainloop_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); - static_assert(cute::rank(typename 
MainloopParams::LayoutB{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementA = typename ElementTraits::type; - using ElementB = typename ElementTraits::type; - using ElementBlockScaleA = typename ElementTraits::type; - using ElementBlockScaleB = typename ElementTraits::type; - - using RingOp = multiply_add; - RingOp fma_op; - - multiplies scale_op; - - static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; - - // Tempo accumulators to seperate blockwise accumulation - typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; - - // Zero out accumulators - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - - const int M = cute::size<0>(mainloop_params.A.layout()); - const int N = cute::size<0>(mainloop_params.B.layout()); - - const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA.layout()); - const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB.layout()); - - assert(ScaleGranularityM && M % ScaleGranularityM == 0 && "ScaleGranularityM must divide M"); - assert(ScaleGranularityN && N % ScaleGranularityN == 0 && "ScaleGranularityN must divide N"); - - cute::Tensor blockscale_A = domain_offset(make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l)); - cute::Tensor blockscale_B = domain_offset(make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l)); - - // Compute on this k-block - for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { - - // Load Blockwise scaling factor from blockscale Tensors for B - int64_t block_k = k / kBlockK; - cute::Tensor scale_a = blockscale_A(_, block_k); - cute::Tensor scale_b = blockscale_B(_, block_k); - - // Load A - ElementAccumulator a_frag[kBlockM]; - for (int m_b = 0; m_b < kBlockM; ++m_b) { - if (m + m_b < 
cute::size<0>(mainloop_params.A.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); - } else { - a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // Load B - ElementAccumulator b_frag[kBlockN]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. - b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); - } else { - b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity - } - } - - // do compute - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); - } - } - - // Apply Groupwise-scaling at kBlockK boundary - // (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary - // (b) Zero-out partial temporary (acc_temp), - // (c) Update permanent (accu) - if ((k+1) % kBlockK == 0) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { - auto scale_a_m_b = scale_a[m_b / ScaleGranularityM]; - for (int n_b = 0; n_b < kBlockN; ++n_b) { - auto scale_b_n_b = scale_b[n_b / ScaleGranularityN]; - ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b; - acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; - acc_temp[m_b][n_b] = ElementAccumulator(0); - } - } - } - - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GETT - Epilogue -template -void gett_epilogue( - EpilogueParams const& epilogue_params, - int64_t m, - int64_t n, - int64_t l, - ElementAccumulator (&acc)[kBlockM][kBlockN]) -{ - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); - 
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); - - using cute::raw_pointer_cast; - - using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::TensorC::value_type; - using ElementD = typename EpilogueParams::TensorD::value_type; - using ElementAux = typename EpilogueParams::TensorAux::value_type; - using ElementBias = typename EpilogueParams::VectorBias::value_type; - using ElementScalar = typename EpilogueParams::ElementScalar; - using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; - using ActivationFunctor = typename EpilogueParams::ActivationFunctor; - using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; - - constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr bool IsScalingAndAmaxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsScalingAndAmaxAuxOutputNeeded = - cute::is_same_v or - cute::is_same_v; - - constexpr bool IsReLUAuxNeeded = - (cute::is_same_v> or - cute::is_same_v>) and - cute::is_same_v; - constexpr bool IsClamp = - cute::is_same_v>; - - constexpr bool IsBackpropFusion = - cute::is_same_v> or - cute::is_same_v>; - - // Input related converter - NumericConverter accumulator_converter; - NumericConverter source_converter; - NumericConverter bias_converter; - [[maybe_unused]] NumericConverter aux_source_converter; - - // Scale related converter - NumericConverter scale_converter; - NumericConverter scaling_factor_converter; - - // Abs max converter - [[maybe_unused]] NumericConverter abs_max_output_converter; - - // Output related converter - NumericConverter destination_converter; - [[maybe_unused]] NumericConverter aux_destination_converter; - NumericConverter dBias_converter; - - // Epilogue operations - multiply_add epilogue_fma; - multiplies mul; - plus add; - - // Activation operation - - auto activation = [] (ElementCompute x, ElementCompute y = ElementCompute(0)) { - if constexpr 
(std::is_same_v) { - return x + y; - } else { - return ActivationFunctor()(x, y); - } - }; - - // Bias binary operation - BiasBinaryOp bias_op; - - // Do conversion - ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); - ElementCompute converted_beta = scale_converter(epilogue_params.beta); - ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); - ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); - ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); - ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); - ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); - - // Init local var - [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); - [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); - - converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); - converted_beta = mul(converted_beta, converted_scale_c); - - ElementCompute inter_accum[kBlockM][kBlockN]; - - for (int m_b = 0; m_b < kBlockM; ++m_b) { - ElementCompute local_dBias = ElementCompute(0); - - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - // Convert every type to ElementCompute first, do compute, convert to output type, write it out - ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // per-row alpha - if (raw_pointer_cast(epilogue_params.Valpha.data())) { - converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); - } - ElementCompute output = mul(converted_alpha, converted_acc); - - if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { - ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? 
n + n_b : m + m_b)); - output = bias_op(output, converted_bias); - } - - if (raw_pointer_cast(epilogue_params.C.data())) { - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // per-row beta - if (epilogue_params.Vbeta.data()) { - converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); - } - output = epilogue_fma(converted_beta, converted_src, output); - } - - if constexpr (IsBackpropFusion) { - ElementAux aux_input = ElementAux(0); - if (raw_pointer_cast(epilogue_params.Aux.data())) { - aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); - } - - output = activation(output, aux_source_converter(aux_input)); - local_dBias = add(local_dBias, output); - } - else { - if (raw_pointer_cast(epilogue_params.Aux.data())) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); - } - - if constexpr (IsReLUAuxNeeded) { - epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); - } else { - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); - } - } - - if constexpr (IsClamp) { // Treat Clamp as ReLU - output = activation(output, {0, std::numeric_limits::max()}); - } - else { - output = activation(output); - } - } - - if constexpr (IsScalingAndAmaxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_output = amax_op(local_abs_max_output, output); - output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); - } - - inter_accum[m_b][n_b] = ElementCompute(output); - } - } // n_b - - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { - if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { - ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); - local_dBias = add(local_dBias, converted_dBias); - epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); - } - } - } // m_b - for (int m_b = 0; m_b < kBlockM; ++m_b) { - for (int n_b = 0; n_b < kBlockN; ++n_b) { - if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { - epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); - } - } - } - -#if defined(_OPENMP) - #pragma omp critical(Abs_Max_Data_Update) -#endif - { - if constexpr (IsScalingAndAmaxOutputNeeded) { - if (epilogue_params.abs_max_D) { - *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); - } - } - - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - if (epilogue_params.abs_max_Aux) { - *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( - *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); - } - } - } -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM - General 
Matrix-Matrix contraction without conjugation options -template < - class MainloopParams, - class EpilogueParams -> -void Gemm3x( - MainloopParams const& mainloop_params, - EpilogueParams const& epilogue_params) -{ - using namespace cute; - - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); - static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " - "with Batchmode are supported"); - // Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett). - Gett(mainloop_params, epilogue_params); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // cutlass::reference::host - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu index 75d3437d..8be4f639 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu @@ -480,7 +480,12 @@ bool verify(const Options &options) { passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); - return passed; + block_SFD.sync_host(); + bool passed_sfd = cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view()); + passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0); + passed_sfd &= 
(cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0); + + return passed && passed_sfd; } /// Execute a given example GEMM computation diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 1d1314d1..c8792122 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -67,9 +67,6 @@ --b=2048 --h=2048 --d=2048 --q=2048 --k=2048 */ -#define DSHOW(x) print(#x ": "); print(x); print("\n"); -#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n"); - #include #include #include @@ -247,8 +244,8 @@ struct Options { << " and are split B-ways, alternatingly +10% and -10%\n" << " with the last batch sized to make it fit\n" << " implies at least residual masking for correctness\n" - << " --sm-count Sets SM count rather than querying it\n" - << " --kernel-filter= Sets regexp to match kernel against\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --kernel-filter= Sets regexp to match kernel against\n" << "\n"; return out; diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu new file mode 100644 index 00000000..1c02a29e --- /dev/null +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu @@ -0,0 +1,865 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3. + + This example showcases the use of CUTLASS to build backward fused + multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting + the NVIDIA Blackwell architecture. + + Background and motivation + ------------------------- + CUTLASS is a highly flexible library that provides open-source building blocks + for tensor core programming for GEMM or GEMM-like problems. Fused multi-head + attention (FMHA) is a foundational kernel for large language models (LLMs) since it + makes long sequence lengths feasible from a memory-usage perspective. 
It also + improves computational efficiency since it transforms an outer-product-like and + a matrix-vector-like GEMM into a fused operation with much higher arithmetic + intensity. For more details, see Dao et al, 2022; Dao, 2023. + Implementing this kernel in CUTLASS enabled easy customization and high + performance. + + Introduction + ------------ + The example targets the NVIDIA Blackwell architecture, and takes advantage of + 5th gen tensor cores and the Tensor Memory Accelerator (TMA), just like + GEMMs do. It provides a backward pass (often abbreviated + bwd in the code). + The code is structured into three layers: The runner (and the reference kernels) + takes care of initialization, measurement, and testing; the device layer + orchestrates kernel calls and partitions workspace; and the kernel layer (just + like the CUTLASS kernel layer. + + Support + --------- + + We support fp16 and fp8 data types with a head dimension of 128. + + Example usage: + $ ./examples/77_blackwell_fmha/77_blackwell_fmha_bwd_fp16 \ + --b=2048 --h=2048 --d=2048 --q=2048 --k=2048 +*/ + +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "reference/fmha_fwd_reference.hpp" +#include "reference/fmha_bwd_reference.hpp" +#include "reference/reference_abs_error.hpp" + +#include "collective/fmha_fusion.hpp" +#include "device/fmha_device_bwd.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class InitStyle { + kOne, kZero, kLinearStride128, kLinearStride1, 
kRandom, kNone +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help = false; + bool error = false; + + int b = 16; + int h = 16; + int h_k = 1; + int q = 1024; + int k = 1024; + int d = 128; + int iterations = 3; + bool verify = false; + bool verbose = false; + + bool causal = false; + int sm_count = 0; + + std::string kernel_filter; + + InitStyle init_style_q = InitStyle::kRandom; + InitStyle init_style_k = InitStyle::kRandom; + InitStyle init_style_v = InitStyle::kRandom; + InitStyle init_style_do = InitStyle::kRandom; + bool skip_reference = false; + + static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) { + std::string s; + cmd.get_cmd_line_argument(name, s, s); + if (s.empty()) { + dst = src; + } + else { + if (s == "r") { + dst = InitStyle::kRandom; + } + else if (s == "0") { + dst = InitStyle::kZero; + } + else if (s == "1") { + dst = InitStyle::kOne; + } + else if (s == "d") { + dst = InitStyle::kLinearStride1; + } + else if (s == "s") { + dst = InitStyle::kLinearStride128; + } + else if (s == "n") { + dst = InitStyle::kNone; + } + else { + std::cout << "Error: " << s << " is not a valid input type.\n"; + std::exit(-1); + } + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + Options defaults; + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("d", d, defaults.d); + cmd.get_cmd_line_argument("h", h, -1); + if (h == -1) h = 2048 / d; + + cmd.get_cmd_line_argument("q", q, -1); + cmd.get_cmd_line_argument("k", k, -1); + if (q == -1) q = k; + if (k == -1) k = q; + if (q == -1 && k == -1) q = k = defaults.q; + + cmd.get_cmd_line_argument("b", b, -1); + if (b == -1) b = 16384 / k; + if (b == 0) b = 1; + + cmd.get_cmd_line_argument("iterations", iterations, 
defaults.iterations); + verify = cmd.check_cmd_line_flag("verify"); + verbose = cmd.check_cmd_line_flag("verbose"); + std::string mask; + cmd.get_cmd_line_argument("mask", mask, ""); + if (mask == "causal") { + causal = true; + } + else { + causal = defaults.causal; + } + + skip_reference = cmd.check_cmd_line_flag("skip-reference"); + cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); + + get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); + get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_k); + get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_v); + get_init_style_argument(cmd, "init-style", init_style_do, defaults.init_style_do); + get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q); + get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k); + get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v); + get_init_style_argument(cmd, "init-style-do", init_style_v, init_style_do); + + cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "77_blackwell_fmha_bwd\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " fused multi-head attention kernels for the backward pass targeting NVIDIA's Blackwell architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --b= Sets the B extent\n" + << " --h= Sets the H extent\n" + << " --q= Sets the Q extent\n" + << " --k= Sets the K extent\n" + << " --d= Sets the D extentn" + << " --iterations= Benchmarking iterations\n" + << " --verify Verify results\n" + << " --verbose Print smem and execution time per kernel\n" + << " --mask= Enables masking\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --kernel-filter= Sets regexp to match kernel against\n" + << "\n"; + + return out; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_block( + DeviceAllocation& block, + uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) { + + switch (init_style) { + case InitStyle::kOne: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 1, (Element) 1); + break; + } + case InitStyle::kZero: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 0, (Element) 0); + break; + } + case InitStyle::kRandom: { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) 0, (Element) 1); + break; + } + case InitStyle::kLinearStride1: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (j % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kLinearStride128: { + 
std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (i % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kNone: { + break; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleResult { + bool passed = false; + bool verified = false; + float runtime_ms = 0; + double tflops_tc_s = 0; + size_t smem_size = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class TileShape, + class DispatchPolicy, + class ActiveMask, + class... KernelOptions +> +struct BwdRunner { + +#ifdef FP8 + using Element = cutlass::float_e4m3_t; +#else + using Element = cutlass::half_t; +#endif + using ElementAccumulator = float; + + // Q K D (H B) + using ProblemShapeType = cute::tuple>; + + using Operation = cutlass::fmha::device::Sm100FmhaBwd; + + using TensorStride = Stride>; // Seq D (H B) + using StrideQ = TensorStride; + using StrideK = TensorStride; + using StrideV = TensorStride; + using StrideO = TensorStride; + using StrideLSE = Stride<_1, Stride>; // Seq (H B) + + // Backwards specific + using StrideDQ = TensorStride; + using StrideDK = TensorStride; + using StrideDV = TensorStride; + using StrideDO = TensorStride; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + + StrideDQ stride_dQ; + StrideDK stride_dK; + StrideDV stride_dV; + StrideDO stride_dO; + + uint64_t seed = 0; + + DeviceAllocation block_Q; + DeviceAllocation block_K; + DeviceAllocation block_V; + DeviceAllocation block_O; + DeviceAllocation block_LSE; + + 
DeviceAllocation block_dQ; + DeviceAllocation block_dK; + DeviceAllocation block_dV; + DeviceAllocation block_dO; + + DeviceAllocation block_ref_dQ; + DeviceAllocation block_ref_dK; + DeviceAllocation block_ref_dV; + + // + // Methods + // + bool verify(const ProblemShapeType& problem_shape) { + auto [Q, K, D, HB] = problem_shape; + auto [H, B] = HB; + + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + select<0,2,3>(problem_shape), + stride_Q); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + select<1,2,3>(problem_shape), + stride_K); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + select<1,2,3>(problem_shape), + stride_V); + + Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), + select<0,2,3>(problem_shape), + stride_O); + + // keep going here! (this might be better in cursor) + + Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), + select<0,3>(problem_shape), + stride_LSE); + + Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), + select<0,2,3>(problem_shape), + stride_dQ); + + Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), + select<1,2,3>(problem_shape), + stride_dK); + + Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), + select<1,2,3>(problem_shape), + stride_dV); + + Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), + select<0,2,3>(problem_shape), + stride_dO); + + fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{}); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-0 : 1e-2; + const double kMeanDiffThresh = sizeof(Element) == 1 ? 
1e-1 : 1e-3; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; + reference_abs_diff(block_dQ, block_ref_dQ, max_diff, mean_diff); + + bool passed_dQ = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_dQ) { + std::cerr << "failed dQ: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_dK, block_ref_dK, max_diff, mean_diff); + + bool passed_dK = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_dK) { + std::cerr << "failed dK: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_dV, block_ref_dV, max_diff, mean_diff); + + bool passed_dV = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_dV) { + std::cerr << "failed dV: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + return passed_dQ && passed_dK && passed_dV; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_shape, Options const& options) { + auto [Q, K, D, HB] = problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + + auto shape_QO = select<0,2,3>(problem_shape); + auto shape_KV = select<1,2,3>(problem_shape); + auto shape_LSE = select<0,3>(problem_shape); + + stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H)); + stride_V = stride_K; + stride_O = stride_Q; + stride_LSE = make_stride(_1{}, make_stride(Q, Q*H)); + + stride_dQ = stride_Q; + stride_dK = stride_K; + stride_dV = stride_V; + stride_dO = stride_O; + + auto lsize = [](auto shape) { + return size(make_shape(1ull, shape)); + }; + + block_Q.reset(lsize(shape_QO)); + block_K.reset(lsize(shape_KV)); + block_V.reset(lsize(shape_KV)); + block_O.reset(lsize(shape_QO)); + 
block_LSE.reset(lsize(shape_LSE)); + + block_dQ.reset(lsize(shape_QO)); + block_dK.reset(lsize(shape_KV)); + block_dV.reset(lsize(shape_KV)); + block_dO.reset(lsize(shape_QO)); + + block_ref_dQ.reset(lsize(shape_QO)); + block_ref_dK.reset(lsize(shape_KV)); + block_ref_dV.reset(lsize(shape_KV)); + + initialize_block(block_Q, seed + 2023, options.init_style_q); + initialize_block(block_K, seed + 2022, options.init_style_k); + initialize_block(block_V, seed + 2021, options.init_style_v); + initialize_block(block_dO, seed + 2020, options.init_style_do); + + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + select<0,2,3>(problem_shape), + stride_Q); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + select<1,2,3>(problem_shape), + stride_K); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + select<1,2,3>(problem_shape), + stride_V); + + Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), + select<0,2,3>(problem_shape), + stride_O); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), + select<0,3>(problem_shape), + stride_LSE); + + if (! 
options.skip_reference) { + fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); + } + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b)); + + initialize(problem_shape, options); + + ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d); + + typename Operation::Arguments arguments{ + problem_shape, + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_O.get(), stride_O, + block_LSE.get(), stride_LSE, + block_dO.get(), stride_dO, + block_dQ.get(), stride_dQ, + block_dK.get(), stride_dK, + block_dV.get(), stride_dV, + softmax_scale, + hw_info + }; + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Kernel::SharedStorageSize; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + + cutlass::Status status = cutlass::Status::kSuccess; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + // Run + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. 
Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + for (int i = 0; i < options.iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Wait for work on the device to complete. + result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + runtime_ms /= static_cast(options.iterations); + + double flops = 10.0 * (std::is_same_v ? 
0.5 : 1.0); + flops *= static_cast(get<0>(problem_shape)); + flops *= static_cast(get<1>(problem_shape)); + flops *= static_cast(get<2>(problem_shape)); + flops *= static_cast(get<3,0>(problem_shape)); + flops *= static_cast(get<3,1>(problem_shape)); + double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tflops_tc_s = tflops_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_shape); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, ExampleResult result, bool verbose) { + std::ios fmt(nullptr); + fmt.copyfmt(std::cout); + std::cout << (result.passed ? (result.verified ? 
" [OK] " : " [--] ") : "[FAIL] "); + std::cout << std::setw(32) << std::left << description; + std::cout.copyfmt(fmt); + std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl; + if (verbose) { + std::cout << " t=" << result.runtime_ms << "ms, " + "smem=" << result.smem_size << "b" << std::endl; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct KernelCoop {}; + +////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _64; + + run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... 
kernel_options) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _128; + + run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main_single(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) { + std::cout + << "This example requires a GPU of NVIDIA's Blackwell Architecture " + << "(compute capability 100a) and CUDA 12.8 or greater.\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
+ hw_info.device_id = 0; + if (options.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + else { + hw_info.sm_count = options.sm_count; + } + + std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " "; + std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " "; + std::cout << "#SM " << hw_info.sm_count << std::endl; + + auto with_causal = [&](auto fn) { + if (options.causal) { + fn(CausalMask{}); + } + else { + fn(NoMask{}); + } + }; + + with_causal([&](auto fusion) { + if (options.d <= 64) { + run_bwd_64(fusion, options, hw_info); + } + else if (options.d <= 128) { + run_bwd_128(fusion, options, hw_info); + } + else { + std::cout << "No kernel instantiated for d=" << options.d << std::endl; + } + }); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + std::vector full_arguments(args, args + argc); + + int result = 0; + + bool recursed = false; + for (size_t i = 1; i < full_arguments.size(); i++) { + if (full_arguments[i].find(',') != std::string::npos) { + auto arg = full_arguments[i]; + size_t eq_pos = arg.find('='); + std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1); + std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1); + for (;;) { + size_t comma_pos = rest.find(','); + std::string current = rest.substr(0, comma_pos); + full_arguments[i] = prefix + current; + std::vector next_args; + for (auto& elem : full_arguments) { next_args.push_back(elem.data()); } + main(argc, next_args.data()); + if (comma_pos == std::string::npos) break; + rest = rest.substr(comma_pos+1); + } + recursed = true; + break; + } + } + + if (! 
recursed) { + main_single(argc, args); + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index bff609fa..f04ebe41 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -28,16 +28,14 @@ set_property( - SOURCE 77_blackwell_fmha.cu - PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") - -set_property( - SOURCE 77_blackwell_fmha_gen.cu - PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") - -set_property( - SOURCE 77_blackwell_mla.cu - PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") + SOURCE + 77_blackwell_fmha.cu + 77_blackwell_fmha_gen.cu + 77_blackwell_mla.cu + 77_blackwell_fmha_bwd.cu + PROPERTY + COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0" +) set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no) set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) @@ -116,5 +114,34 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B) target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v) + cutlass_example_add_executable( + 77_blackwell_fmha_bwd_${PREC} + 77_blackwell_fmha_bwd.cu + TEST_COMMAND_OPTIONS + TEST_BASIC + # TEST_GEN_VARLEN + # TEST_GEN_HDIM64 + # TEST_GEN_GQA + # TEST_GEN_REMAP + # TEST_GEN_CACHEONLY) + ) + target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) + target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v) + + cutlass_example_add_executable( + 77_blackwell_fmha_bwd_sat_${PREC} + 77_blackwell_fmha_bwd.cu + TEST_COMMAND_OPTIONS + TEST_BASIC + # 
TEST_GEN_VARLEN + TEST_GEN_HDIM64 + # TEST_GEN_GQA + # TEST_GEN_REMAP + # TEST_GEN_CACHEONLY) + ) + target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC) + target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v) endforeach() endif() diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md index c8250a7d..a1536dc8 100644 --- a/examples/77_blackwell_fmha/README.md +++ b/examples/77_blackwell_fmha/README.md @@ -22,6 +22,21 @@ The `apply_mask` function is called with the accumulator of the first GEMM and t It is well-suited for applying masks or activations. More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA. +# FMHA for Blackwell: Backward + +This sample provides code for fused multi-head attention backward pass. +It supports HeadDims of 64 and 128, and fp8, fp16, and bf16 input data types. +The blocking in sequence length Q and K is 128, loads are done via TMA. +We support causal masking. +The structure of this code is very similar to the forward pass, and the techniques are analogous. + +There are three kernels to compute backwards, run in this order: +1. `FmhaKernelBwdSumOdO` to compute the sum of the outer product of O and dO. +2. `Sm100FmhaBwdKernelTmaWarpSpecialized` to compute the backward pass. +3. `FmhaKernelBwdConvert` to convert the dQ from fp32 to the final output precision. + +`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.
+ # MLA Inference for Blackwell This sample provides code for fused multi-head latent attention inference in diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp new file mode 100644 index 00000000..80fcdf9f --- /dev/null +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/tensor.hpp" + +#include "../device/fmha.hpp" +#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class Sm100FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + // Q K D HB + cute::tuple> problem_size; + + const Element* ptr_Q; + cute::tuple> stride_Q; + const Element* ptr_K; + cute::tuple> stride_K; + const Element* ptr_V; + cute::tuple> stride_V; + + const Element* ptr_O; + cute::tuple> stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple> stride_LSE; + + const Element* ptr_dO; + cute::tuple> stride_dO; + + Element* ptr_dQ; + cute::tuple> stride_dQ; + Element* ptr_dK; + cute::tuple> stride_dK; + Element* ptr_dV; + cute::tuple> stride_dV; + + ElementAccumulator softmax_scale; + + 
cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdSumOdO + >; + using OperationConvert = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdConvert + >; + + using Operation = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized + >; + using Kernel = typename Operation::Kernel; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments( + Arguments const& args, + ElementAccumulator* sum_odo = nullptr, + ElementAccumulator* scaled_lse = nullptr) { + using namespace cute; + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); + auto log2_e = log2f(expf(1.0f)); + return typename OperationSumOdO::Arguments { + args.problem_size, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + sum_odo, stride_sum_OdO, + args.ptr_LSE, args.stride_LSE, + scaled_lse, stride_scaled_lse, + -1.0f, -log2_e + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + using namespace cute; + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + return typename OperationConvert::Arguments { + args.problem_size, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.softmax_scale + }; + 
} + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { + return typename Operation::Arguments{ + args.problem_size, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + scaled_lse, stride_scaled_lse, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ, + args.softmax_scale }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // scaled LSE vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. 
+ Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments( + args, sum_OdO, args_sum_OdO.stride_sum_OdO, + scaled_lse, args_sum_OdO.stride_scaled_lse, + dQ_acc, args_convert.stride_src_dQ + ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + auto [Q, K, D, HB] = args.problem_size; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. 
+ Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 00000000..c2618bcb --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + tuple> problem_size; + + const ElementAcc* ptr_src_dQ; + tuple> stride_src_dQ; + const ElementAcc* ptr_src_dK; + tuple> stride_src_dK; + const ElementAcc* ptr_src_dV; + tuple> stride_src_dV; + + Element* ptr_dest_dQ; + tuple> stride_dest_dQ; + Element* ptr_dest_dK; + tuple> stride_dest_dK; + Element* ptr_dest_dV; + tuple> stride_dest_dV; + + ElementAcc scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const 
int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) { + auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= count) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAcc value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = static_cast(params.scale * value_src[v]); + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = 
*reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 00000000..44080e2d --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + cute::tuple> problem_size; + + const Element* ptr_O; + cute::tuple> stride_O; + const Element* ptr_dO; + cute::tuple> stride_dO; + + ElementAcc* ptr_sum_OdO; + cute::tuple> stride_sum_OdO; + + const ElementAcc* ptr_lse = nullptr; + cute::tuple> stride_lse; + + ElementAcc* ptr_scaled_lse = nullptr; + cute::tuple> stride_scaled_lse; + + ElementAcc sum_odo_scale = 1.0; + ElementAcc lse_scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int 
kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= get<0>(params.problem_size)) continue; + ElementAcc acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO); + auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); + auto ptr_scaled_lse_bhq = 
ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; + if (params.ptr_scaled_lse) { + *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; + } + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 00000000..e1bd43d5 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1699 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeK = decltype(get<1>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<2>(TileShape{})); + + using TmemAllocator = 
cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kS + TileShapeQ{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + 
using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeKQ = typename CollectiveMmaKQ::TileShape; + using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma; + + // compute dP + using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeVDO = typename CollectiveMmaVDO::TileShape; + using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, 
cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename 
PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) 
cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using ProblemShape = Shape>; // Q K D (H B), eventuall D = (D_QK, D_VO) + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; + using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; + using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; + using TMA_DO = typename 
CollectiveMmaVDO::Params::TMA_B; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0) { + return false; + } + if (D % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q, K, D, HB] = args.problem_shape; + + auto params_kq = CollectiveMmaKQ::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaKQ::Arguments { + args.mainloop.ptr_k, args.mainloop.stride_k, + args.mainloop.ptr_q, args.mainloop.stride_q, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaVDO::Arguments { + args.mainloop.ptr_v, args.mainloop.stride_v, + args.mainloop.ptr_do, args.mainloop.stride_do, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, 
_, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_a, + params_vdo.tma_load_a, + params_kq.tma_load_b, + params_vdo.tma_load_b, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mQ = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mV = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D, HB)); + auto mDO = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D, HB)); + + auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQ = local_tile(mQ, 
TileShapeKQ{}, make_coord(_,_,_), Step{}); + auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_A(gK); + auto tSTgQ = cta_mma_kq.partition_B(gQ); + auto tDPTgV = cta_mma_vdo.partition_A(gV); + auto tDPTgDO = cta_mma_vdo.partition_B(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, 
_0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading 128 values of 32b each + // so 4*32b=128b + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + cutlass::arch::cp_async_zfill<16>( + shared_tensors.smem_lse.begin() + smem_idx, + &mLSE(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = 
TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + cutlass::arch::cp_async<16>( + shared_tensors.smem_sum_odo.begin() + smem_idx, + &mSumOdO(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + cutlass::arch::cp_async<16>( + shared_tensors.smem_lse.begin() + smem_idx, + &mLSE(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, 
blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + cutlass::arch::cp_async_zfill<16>( + shared_tensors.smem_sum_odo.begin() + smem_idx, + &mSumOdO(gmem_idx, blk_coord_batch), + gmem_idx < Q + ); + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename 
PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); + Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); + + Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); + Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + tDVrP.data() = TmemAllocation::kP; + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaKQ tiled_mma_kq; + TiledMmaVDO tiled_mma_vdo; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); + tSTtST.data() = TmemAllocation::kS; + + 
Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + 
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = 
UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + 
++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + auto copy_op = make_cotiled_copy( + Copy_Atom, 
Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, 
size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); 
+ + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. 
wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + auto store_op = SM100_TMEM_STORE_32dp32b8x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST)); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + + + auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, 
_, _, _0{}); + auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); + tDVrP.data() = TmemAllocation::kP; + + auto tiled_r2t = make_tmem_copy(store_op, tDVrP); + auto thread_r2t = tiled_r2t.get_slice(dp_idx); + + auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); + auto tRT_cST = split_wg(thread_r2t.partition_S(tDVcST)); + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + dispatch_bool(std::is_base_of_v && + warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (std::is_base_of_v && decltype(is_causal_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = 
quantize(tTR_rST); + auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); + }); + + // notify for P + cutlass::arch::fence_view_async_tmem_store(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + 
// release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_128{}, _128{}), + make_stride(_1{}, _0{}) + ); + + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = 
SM100_TMEM_LOAD_32dp32b32x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * 
NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = 
kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = 
PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ 
cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + 
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + 
typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, make_coord(blockIdx.y, blockIdx.z)); + auto problem_shape = params.problem_shape; + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } + iter_count -= iter_start; + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + 
params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if 
(warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp new file mode 100644 index 00000000..bb8cfb34 --- /dev/null +++ b/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, /* class TensorDK, class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dQ_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */ + Fusion fusion) { + + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_K + + __syncthreads(); + + for 
(int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mK(idx_K, idx_D, idx_L); + } + mDQ(idx_Q, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, */ class TensorDK, /* class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dK_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */ + Fusion fusion) { + + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = 
static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_Q + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L); + } + mDK(idx_K, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } // for idx_K + } // for idx_L +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, class TensorDK, */ class TensorDV, + class Fusion +> +void __global__ fmha_bwd_reference_dV_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV, + Fusion fusion) { + + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) { + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; + + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L); + ElementAcc rK = mK(idx_K, idx_D0, idx_L); + acc_qk += rQ * rK; + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = 
static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L))); + } // for idx_Q + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + ElementAcc rS = mS[idx_Q]; + ElementAcc rDO = mDO(idx_Q, idx_D, idx_L); + acc += rS * rDO; + } + mDV(idx_K, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } // for idx_K + } // for idx_L +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dQ( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dK( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDK), size<2>(mDK), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + 
fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/ + class Fusion +> +void fmha_bwd_reference_dV( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDV), size<2>(mDV), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, class TensorDK, class TensorDV, + class Fusion +> +void fmha_bwd_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, TensorDK mDK, TensorDV mDV, + Fusion fusion) { + + fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); + fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); + fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp index 48d81101..b7c6b412 100644 --- a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp +++ 
b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -128,7 +128,7 @@ void __global__ fmha_reference_kernel( } if (threadIdx.x == 0) { - mLSE(idx_Q + offset_Q, idx_L) = log(sum) + maxS; + mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; } } diff --git a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu new file mode 100644 index 00000000..d36bf4dd --- /dev/null +++ b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu @@ -0,0 +1,927 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief Grouped GEMM example using CUTLASS 3x APIs for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM120 TensorOp-based warp-specialized kernel + for narrow precisions (FP4) with input Scale Factors. + For this example all scheduling work is performed on the device, utilizing the device-side modification of TMA descriptors + to move between groups/problem_count (represented by groups). + https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device + + To run this example: + + $ ./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10 + + The above example command makes all 10 groups to be sized at the given m, n, k sizes. + Skipping any of the problem dimensions randomizes it across the different groups. + Same applies for alpha and beta values that are randomized across the different groups. 
+ + To run this example for a set of problems using the benchmark option: + + $ ./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "helper.h" +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration 
+using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = float_e2m1_t; // Element type for D matrix operands +using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand + +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Alignment of D matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for internal computation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Epilogue Operator class tag + +// Kernel Perf config +// Cluster Shape fixed to 1x1x1 +using ThreadBlockShape = Shape<_128,_128,_128>; +using ClusterShape = Shape<_1,_1,_1>; +constexpr int OutputSFVectorSize = 16; + +// D = alpha * acc + beta * C +// With BlockScaleFactor generation. 
+using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, LayoutCTag, + ElementC>; + +// Cooperative kernel schedule +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag *, AlignmentC, + ElementD, LayoutCTag *, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation +>::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag *, AlignmentA, + ElementB, LayoutBTag *, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Auto schedule defaults to cooperative schedule +>::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + +// Pingpong kernel schedule +using CollectiveMainloopPingpong = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag *, AlignmentA, + ElementB, LayoutBTag *, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong +>::CollectiveOp; + +using GemmKernelPingpong = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopPingpong, + CollectiveEpilogue +>; + +using GemmPingpong = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename 
Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + OutputSFVectorSize, + cute::is_same_v ? cute::UMMA::Major::K : cute::UMMA::Major::MN + >; +using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; +using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF; + +// Host-side allocations +std::vector stride_A_host; +std::vector stride_B_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +using HostTensorA = cutlass::HostTensor; +using HostTensorB = cutlass::HostTensor; +using HostTensorSF = cutlass::HostTensor; +using HostTensorC = cutlass::HostTensor; +using HostTensorD = cutlass::HostTensor; +std::vector block_A; +std::vector block_B; +std::vector block_SFA; +std::vector block_SFB; +std::vector block_C; +std::vector block_D; +std::vector block_SFD; +std::vector block_ref_D; +std::vector block_ref_SFD; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFD; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; 
+cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; +// A matrix wide constant value to scale the output matrix +// Avoids generating small FP4 values. +// NormConst is a single device-side constant value, its not per-batch or per-group +cutlass::DeviceAllocation norm_constant_device; + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + bool verification = true; + bool use_pdl = false; + + float alpha = std::numeric_limits::max(); + float beta = std::numeric_limits::max(); + float norm_constant = 1.0; + int iterations = 10; + int m = 1024, n = 2048, k = 512, groups = 10; + RasterOrderOptions raster_order = RasterOrderOptions::AlongN; + int max_sm_count = INT_MAX; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + if 
(cmd.check_cmd_line_flag("use_pdl")) { + use_pdl = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, std::numeric_limits::max()); + cmd.get_cmd_line_argument("beta", beta, std::numeric_limits::max()); + cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0)); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (n < 1) { + n = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + 
break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "79d_blackwell_geforce_nvfp4_grouped_gemm\n\n" + << " Blackwell Block Scaled Narrow Precision Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --norm_constant= Epilogue scalar normalization constant for the output matrix\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --max_sm_count= Run kernels using only these number of SMs\n" + << " --no_verif Do not run (host-side) verification kernels\n" + << " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "79d_blackwell_geforce_nvfp4_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // 
Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + 
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + stride_A_host.push_back(stride_A); + stride_B_host.push_back(stride_B); + layout_SFA_host.push_back(layout_SFA); + layout_SFB_host.push_back(layout_SFB); + stride_C_host.push_back(stride_C); + stride_D_host.push_back(stride_D); + + block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A)))); + block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B)))); + block_SFA.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFA))))); + block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB))))); + block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C)))); + block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD))))); + block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_ref_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD))))); + } + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + uint64_t seed = 2020; + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers 
+ // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_SFA_host(options.groups); + std::vector ptr_SFB_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_SFD_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + + initialize_block(block_A.at(i).host_view(), seed + 2021); + initialize_block(block_B.at(i).host_view(), seed + 2022); + initialize_block(block_C.at(i).host_view(), seed + 2023); + initialize_block(block_SFA.at(i).host_view(), seed + 2024); + initialize_block(block_SFB.at(i).host_view(), seed + 2025); + + block_A.at(i).sync_device(); + block_B.at(i).sync_device(); + block_C.at(i).sync_device(); + block_SFA.at(i).sync_device(); + block_SFB.at(i).sync_device(); + + ptr_A_host.at(i) = block_A.at(i).device_data(); + ptr_B_host.at(i) = block_B.at(i).device_data(); + ptr_SFA_host.at(i) = block_SFA.at(i).device_data(); + ptr_SFB_host.at(i) = block_SFB.at(i).device_data(); + ptr_C_host.at(i) = block_C.at(i).device_data(); + ptr_D_host.at(i) = block_D.at(i).device_data(); + ptr_SFD_host.at(i) = block_SFD.at(i).device_data(); + + alpha_host.push_back((options.alpha == std::numeric_limits::max()) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == std::numeric_limits::max()) ? 
static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(ptr_SFB_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_SFD.reset(options.groups); + ptr_SFD.copy_from_host(ptr_SFD_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + norm_constant_device.reset(1); + norm_constant_device.copy_from_host(&options.norm_constant); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
+ hw_info.device_id = 0; + hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count); + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + // If alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + if (options.alpha != std::numeric_limits::max()){ + // Single alpha for all groups + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.dAlpha = {_0{}, _0{}, 0}; + } + else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + // Only one alpha per each group + fusion_args.dAlpha = {_0{}, _0{}, 1}; + } + if (options.beta != std::numeric_limits::max()) { + // Single beta for all groups + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dBeta = {_0{}, _0{}, 0}; + } + else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + // Only one beta per each group + fusion_args.dBeta = {_0{}, _0{}, 1}; + } + + // Output Block SF + fusion_args.block_scale_factor_ptr = ptr_SFD.get(); // Enable for SF Output + fusion_args.norm_constant_ptr = norm_constant_device.get(); // Enable for SF Output + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; 
+ } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.at(i).host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB); + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{tensor_A, 
tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C); + auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D); + auto tensor_ref_SFD = cute::make_tensor(make_iterator(block_ref_SFD.at(i).host_data()), layout_SFD); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementCompute, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_ref_D), // TensorD + decltype(tensor_ref_SFD), // TensorSfD + cute::Int, + cutlass::reference::host::SfStrategy::SfDGen + > epilogue_params {alpha_host.at(i), beta_host.at(i), tensor_C, tensor_ref_D, tensor_ref_SFD, options.norm_constant}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.at(i).sync_host(); + block_SFD.at(i).sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view()); + passed &= cutlass::reference::host::TensorEquals(block_ref_SFD.at(i).host_view(), block_SFD.at(i).host_view()); + // Check that the tensors have non-zero norms + passed &= (cutlass::reference::host::TensorNorm(block_ref_D.at(i).host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.at(i).host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_ref_SFD.at(i).host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_SFD.at(i).host_view()) > 0); + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << 
alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + if (options.verification) { + std::cout << " Host-side verification is now running - may be very slow for large cases." << std::endl; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + } + else { + std::cout << " Verfication is turned off for this run." << std::endl; + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || + ((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8) + ) + ) { + std::cerr << "This example requires CUDA 12.8 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(&current_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (!(props.major == 12 && props.minor == 0)) { + std::cerr + << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120a).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + allocate(options); + initialize(options); + + // + // Evaluate CUTLASS kernels + // + + std::cout << "Running kernel with Cooperative kernel schedule:" << std::endl; + run(options, false /*host_problem_shapes_available*/); + std::cout << "Running kernel with Pingpong kernel schedule:" << std::endl; + run(options, false /*host_problem_shapes_available*/); +#endif + + return 0; +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/79_blackwell_geforce_gemm/CMakeLists.txt b/examples/79_blackwell_geforce_gemm/CMakeLists.txt index cb7e3e97..b689c85e 100644 --- a/examples/79_blackwell_geforce_gemm/CMakeLists.txt +++ b/examples/79_blackwell_geforce_gemm/CMakeLists.txt @@ -28,6 +28,24 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +set(TEST_RANDOM --iterations=0) # Random problem sizes +set(TEST_RANDOM_LARGE_GROUP --groups=50 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=51 --iterations=0) # Fixed problem sizes + +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) # Small problem sizes + +set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes +set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes + if (CUTLASS_NVCC_ARCHS MATCHES 120a) cutlass_example_add_executable( 79a_blackwell_geforce_nvfp4_bf16_gemm @@ -44,4 +62,22 @@ cutlass_example_add_executable( 79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu ) +cutlass_example_add_executable( + 79d_blackwell_geforce_nvfp4_grouped_gemm + 79d_blackwell_geforce_nvfp4_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + 
TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP +) + endif() diff --git a/examples/cute/tutorial/blackwell/01_mma_sm100.cu b/examples/cute/tutorial/blackwell/01_mma_sm100.cu index 3f73140a..a11fb17c 100644 --- a/examples/cute/tutorial/blackwell/01_mma_sm100.cu +++ b/examples/cute/tutorial/blackwell/01_mma_sm100.cu @@ -61,7 +61,8 @@ #include // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. -#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -122,7 +123,9 @@ struct SharedStorage alignas(128) cute::ArrayEngine> A; alignas(128) cute::ArrayEngine> B; - alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } @@ -225,6 +228,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. 
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -233,10 +248,8 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) } __syncthreads(); - // Barrier Initialization - uint32_t elect_one_thr = cute::elect_one_sync(); - uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + // Barrier Initialization // Barriers in SMEM initialized by a single thread. if (elect_one_warp && elect_one_thr) { cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1); @@ -306,6 +319,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. 
-#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -124,6 +125,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; @@ -228,6 +231,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -269,9 +284,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - uint32_t elect_one_thr = cute::elect_one_sync(); - uint32_t elect_one_warp = (threadIdx.x / 32 == 0); - // Barriers in SMEM initialized by a single thread. 
if (elect_one_warp && elect_one_thr) { cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1); @@ -346,6 +358,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. -#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -129,6 +130,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; @@ -231,6 +234,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. 
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -305,10 +320,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - - uint32_t elect_one_thr = cute::elect_one_sync(); - uint32_t elect_one_warp = (threadIdx.x / 32 == 0); - // Barriers in SMEM initialized by a single thread. if (elect_one_warp && elect_one_thr) { // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) @@ -385,6 +396,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. 
-#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -132,6 +133,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; @@ -234,6 +237,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator2Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -262,6 +277,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Construct the CTA-in-Cluster coordinate for multicasting auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster())); + auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; // Project the cluster_layout for tma_A along the N-modes auto [tAgA, tAsA] = 
tma_partition(tma_atom_A, @@ -299,10 +315,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - auto elect_one_thr = cute::elect_one_sync(); - auto elect_one_warp = (threadIdx.x / 32 == 0); - auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; - // Barriers in SMEM should be initialized by a single thread. if (elect_one_warp && elect_one_thr) { // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) @@ -386,6 +398,15 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) axpby(alpha, tDrAcc, beta, tDrC); // Store RMEM -> GMEM copy(tDrC, tDgD); + + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template // CuTe tensor implementation #include // CuTe functions for querying the details of cluster launched #include // Compile time in constants such as _1, _256 etc. 
-#include +#include // Auto vectorized copy operation +#include // TMEM allocator for SM100 // Tutorial helpers #include "example_utils.hpp" @@ -140,6 +141,8 @@ struct SharedStorage alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(tensors.mainloop.A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(tensors.mainloop.B.begin()), BSmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sC() { return make_tensor(make_smem_ptr(tensors.C.begin()), CSmemLayout{}); } @@ -247,6 +250,18 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + using TmemAllocator = cute::TMEM::Allocator2Sm; + TmemAllocator tmem_allocator{}; + + if (elect_one_warp) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } + __syncthreads(); // Wait for all threads until warp0 allocates TMEM + tCtAcc.data() = shared_storage.tmem_base_ptr; + if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) @@ -275,6 +290,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Construct the CTA-in-Cluster coordinate for multicasting auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster())); + auto elect_one_cta = 
get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; // Project the cluster_layout for tma_A along the N-modes auto [tAgA, tAsA] = tma_partition(tma_atom_A, @@ -312,10 +328,6 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // Barrier Initialization - auto elect_one_thr = cute::elect_one_sync(); - auto elect_one_warp = (threadIdx.x / 32 == 0); - auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; - // Barriers in SMEM should be initialized by a single thread. if (elect_one_warp && elect_one_thr) { // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) @@ -441,6 +453,14 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) } __syncthreads(); // All threads sync with issuing thread } + __syncthreads(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + // Then deallocate TMEM + if (elect_one_warp) { + tmem_allocator.release_allocation_lock(); + tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } template #include +#include #include #include #include @@ -277,34 +278,13 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f) // find and find_if // -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -find_if(T const& t, F&& f, seq) -{ - if constexpr (decltype(f(get(t)))::value) { - return cute::C{}; - } else - if constexpr (sizeof...(Is) == 0) { - return cute::C{}; - } else { - return find_if(t, f, seq{}); - } - - CUTE_GCC_UNREACHABLE; -} - -} // end namespace detail - template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f) { if constexpr (is_tuple::value) { - return detail::find_if(t, f, tuple_seq{}); + return detail::tapply(t, f, [] (auto... a) { return cute::C>{}; }, tuple_seq{}); } else { return cute::C{}; } @@ -326,7 +306,7 @@ auto any_of(T const& t, F&& f) { if constexpr (is_tuple::value) { - return detail::apply(cute::transform(t, f), [&] (auto const&... 
a) { return (false_type{} || ... || a); }, tuple_seq{}); + return detail::tapply(t, f, [] (auto... a) { return (false_type{} || ... || a); }, tuple_seq{}); } else { return f(t); } @@ -340,7 +320,7 @@ auto all_of(T const& t, F&& f) { if constexpr (is_tuple::value) { - return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq{}); + return detail::tapply(t, f, [] (auto... a) { return (true_type{} && ... && a); }, tuple_seq{}); } else { return f(t); } diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp index ba22ef1c..524a47ef 100644 --- a/include/cute/arch/cluster_sm90.hpp +++ b/include/cute/arch/cluster_sm90.hpp @@ -31,6 +31,7 @@ #pragma once #include +#include // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index 91589538..2383b4e6 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -72,6 +72,27 @@ # define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED #endif @@ -91,8 
+112,11 @@ #endif // {add, mul, fma}.f32x2 PTX -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) - #define CUTE_ARCH_FLOAT2_MATH_ENABLED +#if defined(CUTLASS_ARCH_MMA_SM100_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) + // Enable CuTe MMA Atoms +# define CUTE_ARCH_FFMA2_SM100_ENABLED + // Enable f32x2 PTX generation +# define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif #if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) @@ -109,3 +133,37 @@ # endif #endif +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# endif +#endif + +// {add, mul, fma}.f32x2 PTX +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) + #define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + diff --git a/include/cute/arch/copy_sm100.hpp b/include/cute/arch/copy_sm100.hpp index 19b13841..aa969afe 100644 --- a/include/cute/arch/copy_sm100.hpp +++ b/include/cute/arch/copy_sm100.hpp @@ -28,10 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - -// - -// #pragma once #include @@ -316,17 +312,14 @@ struct SM100_U8x16_STSM_T } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cute - //////////////////////////////////////////////////////////////////////////////////////////////////// // // UTCCP PTX definitions // //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace cute { +namespace SM100::TMEM::UTCCP { + // 128 data path lanes, 256-bit pattern, 1cta mode struct SM100_UTCCP_128dp256bit_1cta { @@ -558,21 +551,19 @@ struct SM100_UTCCP_2x64dp128bitlw0123_2cta } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cute +} // end namespace SM100::TMEM::UTCCP //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace cute { +namespace SM100::TMEM::LOAD { //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// // -// TMEM_LOAD PTX definitions +// TMEM LOAD PTX definitions // //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -3945,7 +3936,6 @@ struct SM100_TMEM_LOAD_32dp32b128x } }; - //////////////////////////////////////////////////////////////////////////////////////////////////// // 32 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read @@ -4065,9 +4055,21 @@ struct SM100_TMEM_LOAD_32dp32b128x_16b //////////////////////////////////////////////////////////////////////////////////////////////////// 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM100::TMEM::LOAD + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::STORE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// // -// TMEM_STORE PTX definitions +// TMEM STORE PTX definitions // //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -4086,8 +4088,8 @@ struct SM100_TMEM_STORE_16dp256b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4110,8 +4112,8 @@ struct SM100_TMEM_STORE_16dp256b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4136,8 +4138,8 @@ struct SM100_TMEM_STORE_16dp256b2x asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4163,8 +4165,8 @@ struct SM100_TMEM_STORE_16dp256b2x_16b asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.unpack::16b.b32" "[%0]," "{%1, %2, %3, 
%4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4194,8 +4196,8 @@ struct SM100_TMEM_STORE_16dp256b4x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4227,8 +4229,8 @@ struct SM100_TMEM_STORE_16dp256b4x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4268,8 +4270,8 @@ struct SM100_TMEM_STORE_16dp256b8x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4313,8 +4315,8 @@ struct SM100_TMEM_STORE_16dp256b8x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4374,8 +4376,8 @@ struct SM100_TMEM_STORE_16dp256b16x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4443,8 +4445,8 @@ struct SM100_TMEM_STORE_16dp256b16x_16b "%49, %50, %51, %52," "%53, %54, %55, 
%56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4544,8 +4546,8 @@ struct SM100_TMEM_STORE_16dp256b32x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -4661,8 +4663,8 @@ struct SM100_TMEM_STORE_16dp256b32x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -4716,8 +4718,8 @@ struct SM100_TMEM_STORE_16dp128b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4740,8 +4742,8 @@ struct SM100_TMEM_STORE_16dp128b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.unpack::16b.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4764,8 +4766,8 @@ struct SM100_TMEM_STORE_16dp128b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), 
"r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4788,8 +4790,8 @@ struct SM100_TMEM_STORE_16dp128b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -4814,8 +4816,8 @@ struct SM100_TMEM_STORE_16dp128b4x asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4841,8 +4843,8 @@ struct SM100_TMEM_STORE_16dp128b4x_16b asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.unpack::16b.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -4872,8 +4874,8 @@ struct SM100_TMEM_STORE_16dp128b8x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4905,8 +4907,8 @@ struct SM100_TMEM_STORE_16dp128b8x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4946,8 +4948,8 @@ struct SM100_TMEM_STORE_16dp128b16x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, 
%32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -4991,8 +4993,8 @@ struct SM100_TMEM_STORE_16dp128b16x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5052,8 +5054,8 @@ struct SM100_TMEM_STORE_16dp128b32x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5121,8 +5123,8 @@ struct SM100_TMEM_STORE_16dp128b32x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5222,8 +5224,8 @@ struct SM100_TMEM_STORE_16dp128b64x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -5339,8 +5341,8 @@ struct SM100_TMEM_STORE_16dp128b64x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -5394,8 
+5396,8 @@ struct SM100_TMEM_STORE_16dp64b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5418,8 +5420,8 @@ struct SM100_TMEM_STORE_16dp64b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.unpack::16b.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5442,8 +5444,8 @@ struct SM100_TMEM_STORE_16dp64b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5466,8 +5468,8 @@ struct SM100_TMEM_STORE_16dp64b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.unpack::16b.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5490,8 +5492,8 @@ struct SM100_TMEM_STORE_16dp64b4x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5514,8 +5516,8 @@ struct SM100_TMEM_STORE_16dp64b4x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), 
"r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -5540,8 +5542,8 @@ struct SM100_TMEM_STORE_16dp64b8x asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -5567,8 +5569,8 @@ struct SM100_TMEM_STORE_16dp64b8x_16b asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.unpack::16b.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -5598,8 +5600,8 @@ struct SM100_TMEM_STORE_16dp64b16x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5631,8 +5633,8 @@ struct SM100_TMEM_STORE_16dp64b16x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5672,8 +5674,8 @@ struct SM100_TMEM_STORE_16dp64b32x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5717,8 +5719,8 @@ struct SM100_TMEM_STORE_16dp64b32x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : 
"r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5778,8 +5780,8 @@ struct SM100_TMEM_STORE_16dp64b64x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5847,8 +5849,8 @@ struct SM100_TMEM_STORE_16dp64b64x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -5948,8 +5950,8 @@ struct SM100_TMEM_STORE_16dp64b128x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -6065,8 +6067,8 @@ struct SM100_TMEM_STORE_16dp64b128x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -6120,8 +6122,8 @@ struct SM100_TMEM_STORE_16dp32b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.b32" "[%0] , 1," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6144,8 +6146,8 @@ struct 
SM100_TMEM_STORE_16dp32b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.unpack::16b.b32" "[%0] , 2," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6168,8 +6170,8 @@ struct SM100_TMEM_STORE_16dp32b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.b32" "[%0] , 2," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6192,8 +6194,8 @@ struct SM100_TMEM_STORE_16dp32b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.unpack::16b.b32" "[%0] , 4," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6216,8 +6218,8 @@ struct SM100_TMEM_STORE_16dp32b4x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.b32" "[%0] , 4," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6240,8 +6242,8 @@ struct SM100_TMEM_STORE_16dp32b4x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.unpack::16b.b32" "[%0] , 8," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6266,8 +6268,8 @@ struct SM100_TMEM_STORE_16dp32b8x asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.b32" "[%0] , 8," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + 
"%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -6293,8 +6295,8 @@ struct SM100_TMEM_STORE_16dp32b8x_16b asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.unpack::16b.b32" "[%0] , 16," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -6324,8 +6326,8 @@ struct SM100_TMEM_STORE_16dp32b16x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6357,8 +6359,8 @@ struct SM100_TMEM_STORE_16dp32b16x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6398,8 +6400,8 @@ struct SM100_TMEM_STORE_16dp32b32x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6443,8 +6445,8 @@ struct SM100_TMEM_STORE_16dp32b32x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6504,8 +6506,8 @@ struct SM100_TMEM_STORE_16dp32b64x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + 
"%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6573,8 +6575,8 @@ struct SM100_TMEM_STORE_16dp32b64x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -6674,8 +6676,8 @@ struct SM100_TMEM_STORE_16dp32b128x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -6791,8 +6793,8 @@ struct SM100_TMEM_STORE_16dp32b128x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -6846,8 +6848,8 @@ struct SM100_TMEM_STORE_32dp32b1x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6870,8 +6872,8 @@ struct SM100_TMEM_STORE_32dp32b1x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.unpack::16b.b32" "[%0]," - "{%1};\n" - : + "{%1};\n" + : : "r"(dst_addr), "r"(src0) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6894,8 +6896,8 @@ struct 
SM100_TMEM_STORE_32dp32b2x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6918,8 +6920,8 @@ struct SM100_TMEM_STORE_32dp32b2x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.unpack::16b.b32" "[%0]," - "{%1, %2};\n" - : + "{%1, %2};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6942,8 +6944,8 @@ struct SM100_TMEM_STORE_32dp32b4x #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6966,8 +6968,8 @@ struct SM100_TMEM_STORE_32dp32b4x_16b #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.unpack::16b.b32" "[%0]," - "{%1, %2, %3, %4};\n" - : + "{%1, %2, %3, %4};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); #else CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); @@ -6992,8 +6994,8 @@ struct SM100_TMEM_STORE_32dp32b8x asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -7019,8 +7021,8 @@ struct SM100_TMEM_STORE_32dp32b8x_16b asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.unpack::16b.b32" "[%0]," "{%1, %2, %3, %4," - "%5, %6, %7, %8};\n" - : + "%5, %6, %7, %8};\n" + : : "r"(dst_addr), "r"(src0), 
"r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); #else @@ -7050,8 +7052,8 @@ struct SM100_TMEM_STORE_32dp32b16x "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7083,8 +7085,8 @@ struct SM100_TMEM_STORE_32dp32b16x_16b "{%1, %2, %3, %4," "%5, %6, %7, %8," "%9, %10, %11, %12," - "%13, %14, %15, %16};\n" - : + "%13, %14, %15, %16};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7124,8 +7126,8 @@ struct SM100_TMEM_STORE_32dp32b32x "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7169,8 +7171,8 @@ struct SM100_TMEM_STORE_32dp32b32x_16b "%17, %18, %19, %20," "%21, %22, %23, %24," "%25, %26, %27, %28," - "%29, %30, %31, %32};\n" - : + "%29, %30, %31, %32};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7230,8 +7232,8 @@ struct SM100_TMEM_STORE_32dp32b64x "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7299,8 +7301,8 @@ struct SM100_TMEM_STORE_32dp32b64x_16b "%49, %50, %51, %52," "%53, %54, %55, %56," "%57, %58, %59, %60," - "%61, %62, %63, %64};\n" - : + "%61, %62, %63, %64};\n" + : : 
"r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), "r"(src04), "r"(src05), "r"(src06), "r"(src07), "r"(src08), "r"(src09), "r"(src10), "r"(src11), @@ -7400,8 +7402,8 @@ struct SM100_TMEM_STORE_32dp32b128x "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -7517,8 +7519,8 @@ struct SM100_TMEM_STORE_32dp32b128x_16b "%113, %114, %115, %116," "%117, %118, %119, %120," "%121, %122, %123, %124," - "%125, %126, %127, %128};\n" - : + "%125, %126, %127, %128};\n" + : : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), "r"(src004), "r"(src005), "r"(src006), "r"(src007), "r"(src008), "r"(src009), "r"(src010), "r"(src011), @@ -7561,7 +7563,8 @@ struct SM100_TMEM_STORE_32dp32b128x_16b //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cute +} // namespace SM100::TMEM::STORE //////////////////////////////////////////////////////////////////////////////////////////////////// +} // end namespace cute diff --git a/include/cute/arch/mma_sm100.hpp b/include/cute/arch/mma_sm100.hpp index 2fa532d2..749da816 100644 --- a/include/cute/arch/mma_sm100.hpp +++ b/include/cute/arch/mma_sm100.hpp @@ -29,7 +29,6 @@ * **************************************************************************************************/ // - // #pragma once @@ -37,6 +36,48 @@ #include #include +#include + namespace cute { +struct SM100_2x1x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float2[1]; + using BRegisters = float[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + fma(float2 & d01, + float2 const& a01, + float const& b0, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + 
cute::fma(d01, a01, make_float2(b0, b0), c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_2x1x1_F32F32F32F32 without CUTE_ARCH_FLOAT2_MATH_ENABLED"); +#endif + } +}; + +struct SM100_1x2x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float[1]; + using BRegisters = float2[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + fma(float2 & d01, + float const& a0, + float2 const& b01, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + cute::fma(d01, make_float2(a0, a0), b01, c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_1x2x1_F32F32F32F32 without CUTE_ARCH_FFMA2_SM100_ENABLED"); +#endif + } +}; + } // namespace cute diff --git a/include/cute/arch/tmem_allocator_sm100.hpp b/include/cute/arch/tmem_allocator_sm100.hpp index 9839e740..680e237f 100644 --- a/include/cute/arch/tmem_allocator_sm100.hpp +++ b/include/cute/arch/tmem_allocator_sm100.hpp @@ -28,19 +28,34 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ -// -// #pragma once #include -#include -#include - -#include +#include +#include +#include namespace cute::TMEM { +// +// TMEM Addressing Constants +// + +// 128 DP x 512 COL x uint32_t-addressing +using MAX_CAPACITY_BITS = Int<128*512*32>; + +// TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t) +using DP_b = cute::constant; + +// TMEM DP stride in type-T addressing +template +using DP = cute::constant::OffsetShift)>; + +// +// TMEM Allocators +// + // All operations of this class require that only a single warp uniformly participates class Allocator1Sm { public: @@ -77,8 +92,8 @@ public: asm volatile( "{\n\t" "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" - "}" - : + "}" + : : "r"(tmem_ptr), "r"(num_columns)); #else CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); @@ -130,7 +145,7 @@ public: } /** - * Frees the TMEM corresponding to the pointer and slice count provided. + * Frees the TMEM corresponding to the pointer and slice count provided. * Release the TMEM after checking that the CTA issuing the free does indeed own the corresponding slices. * @param tmem_ptr Base address of the TMEM address space being freed. * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. 
@@ -146,8 +161,8 @@ public: asm volatile( "{\n\t" "tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1; \n\t" - "}" - : + "}" + : : "r"(tmem_ptr), "r"(num_columns)); #else CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp index 6a767ae3..594149d4 100644 --- a/include/cute/atom/copy_traits_sm100.hpp +++ b/include/cute/atom/copy_traits_sm100.hpp @@ -28,13 +28,11 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -// - -// #pragma once #include +#include #include #include @@ -230,92 +228,11 @@ struct Copy_Traits using RefLayout = SrcLayout; }; -namespace TMEM { - using MAX_CAPACITY_BITS = Int<128*512*32>; // 128 DP x 512 COL x uint32_t-addressing - - template // TMEM DP stride in type-T addressing - using DP = cute::constant::OffsetShift)>; - - using DP_b = cute::constant; // TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t) -} - -// TMEM_LOAD copy_unpack -template -struct TMEM_LOAD_Unpack -{ - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_tmem::value, "Expected TMEM src."); - static_assert(is_rmem::value, "Expected RMEM dst."); - - using SrcType = typename TS::value_type; - CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), - "Expected src to have the specific TMEM layout required by CopyOp."); - - uint32_t tmem_addr = raw_pointer_cast(src.data()); - - using RegTypeDst = typename remove_extent::type; - Tensor rD = recast(dst); - - constexpr int RegNumDst = extent::value; - CUTE_STATIC_ASSERT_V(size(rD) == Int{}, - "In CopyAtom, dst layout doesn't vectorize into registers. 
This dst layout is incompatible with this CopyOp."); - - // thread idx <=> DP lane assert. - // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. -#if defined(__CUDA_ARCH__) && !defined(NDEBUG) - assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); -#endif - - detail::explode(CopyOp::copy, - &tmem_addr, seq<0>{}, - rD, make_seq{}); - } -}; - -// TMEM_STORE copy_unpack -template -struct TMEM_STORE_Unpack -{ - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected RMEM src."); - static_assert(is_tmem::value, "Expected TMEM dst."); - - using RegTypeSrc = typename remove_extent::type; - Tensor rS = recast(src); - - constexpr int RegNumSrc = extent::value; - CUTE_STATIC_ASSERT_V(size(rS) == Int{}, - "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); - - using DstType = typename TD::value_type; - CUTE_STATIC_ASSERT_V((coalesce(layout(dst)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), - "Expected dst to have the specific TMEM layout required by CopyOp."); - - uint32_t tmem_addr = raw_pointer_cast(dst.data()); - - // thread idx <=> DP lane assert. - // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. 
-#if defined(__CUDA_ARCH__) && !defined(NDEBUG) - assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); -#endif - - detail::explode(CopyOp::copy, - rS, make_seq{}, - &tmem_addr, seq<0>{}); - } -}; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM Traits and Utilities +// +//////////////////////////////////////////////////////////////////////////////////////////////////// template struct Copy_Atom; @@ -418,10 +335,2406 @@ make_tmem_warp_partitioner(Tensor const& tmem) return make_tiler_impl(layout_tv, tiler); } -} // end namespace cute +namespace SM100::TMEM::LOAD { + +// +// Specialized copy_unpack implementation for SM100::TMEM::LOAD instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_tmem::value, "Expected TMEM src."); + static_assert(is_rmem::value, "Expected RMEM dst."); + + using SrcType = typename TS::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected src to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(src.data()); + + using RegTypeDst = typename remove_extent::type; + Tensor rD = recast(dst); + + constexpr int RegNumDst = extent::value; + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this CopyOp."); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attempting to access DP lane within sub-partition. 
+#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + &tmem_addr, seq<0>{}, + rD, make_seq{}); +} + +} // end namespace SM100::TMEM::LOAD + +namespace SM100::TMEM::STORE { + +// +// Specialized copy_unpack implementation for SM100::TMEM::STORE instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem::value, "Expected RMEM src."); + static_assert(is_tmem::value, "Expected TMEM dst."); + + using RegTypeSrc = typename remove_extent::type; + Tensor rS = recast(src); + + constexpr int RegNumSrc = extent::value; + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + + using DstType = typename TD::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(dst)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected dst to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(dst.data()); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_STORE thread attempting to access DP lane within sub-partition. 
+#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + rS, make_seq{}, + &tmem_addr, seq<0>{}); +} + +} // end namespace SM100::TMEM::STORE + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_LOAD Copy Traits +// //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace cute { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + // Logical bit id to bit idx (address) + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< 
_1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using 
SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = 
Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x; + +template <> +struct Copy_Traits +{ + using 
ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< 
_1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< _1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + 
using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + 
Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = 
Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + 
using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x_16b; + +template <> +struct Copy_Traits +{ + 
using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using 
SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using 
DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_STORE Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; 
+ using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x; + +template <> 
+struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using 
DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x_16b; + +template 
<> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using 
DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x; + +template <> 
+struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using 
DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x_16b; + +template <> 
+struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using 
DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x; + +template <> +struct 
Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = 
typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1159,2183 +3472,38 @@ tmem_load_to_store(CopyOp) { } } -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace TMEM //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMEM_LOAD Copy Traits -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - // Logical thread id to thread idx (warp) - using ThrID = Layout<_32>; - // Logical bit id to bit idx (address) - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_64, _2>>, - Stride,Stride< _1,_2048>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2>>, - Stride,Stride< _1,_2048>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _2>>, - Stride,Stride< _1,_4096,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _2>>, - Stride,Stride< _1,_4096,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using 
ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _4>>, - Stride,Stride< _1,_8192,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _4>>, - Stride,Stride< _1,_8192,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _8>>, - Stride,Stride< _1,_16384,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _8>>, - Stride,Stride< _1,_16384,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _16>>, - Stride,Stride< _1,_32768,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _16>>, - Stride,Stride< _1,_32768,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _32>>, - Stride,Stride< _1,_65536,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_64, _2, _32>>, - Stride,Stride< _1,_65536,_256>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2>>, - Stride,Stride< _1,_1024>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2>>, - Stride,Stride< _1,_1024>>>; - using RefLayout = SrcLayout; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _2>>, - Stride,Stride< _1,_2048,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _2>>, - Stride,Stride< _1,_2048,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _4>>, - Stride,Stride< _1,_4096,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _4>>, - Stride,Stride< _1,_4096,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _8>>, - Stride,Stride< 
_1,_8192,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _8>>, - Stride,Stride< _1,_8192,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _16>>, - Stride,Stride< _1,_16384,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _16>>, - Stride,Stride< _1,_16384,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _32>>, - Stride,Stride< _1,_32768,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - 
using DstLayout = Layout,Shape <_32, _2, _32>>, - Stride,Stride< _1,_32768,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _64>>, - Stride,Stride< _1,_65536,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2, _64>>, - Stride,Stride< _1,_65536,_128>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_32>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_32>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using 
DstLayout = Layout,Shape <_32, _2>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _2>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _4>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _4>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32, _8>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - 
using DstLayout = Layout,Shape <_32, _8>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32,_16>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32,_16>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32,_32>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32,_32>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - 
using DstLayout = Layout,Shape <_32,_64>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32,_64>>, - Stride,Stride< _1,_64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32,_128>>, - Stride,Stride< _1, _64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,Shape <_32,_128>>, - Stride,Stride< _1, _64>>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_32>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = 
Layout,_32>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_64>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_64>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_128>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_128>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_256>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_256>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_512>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_512>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_1024>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_1024>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_2048>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_2048>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_4096>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _16>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout,_4096>, - Stride, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_32, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, 
_32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_32, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_64, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_64, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_128, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_128, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_256, 
_1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_256, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_512, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_512, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_1024, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_1024, _1>>; - using RefLayout = SrcLayout; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_2048, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_2048, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, - Stride< _1,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_4096, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct Copy_Traits - : TMEM_LOAD_Unpack -{ - using ThrID = Layout<_32>; - using ValID = Layout, _32>, - Stride,TMEM::DP_b>>; - using SrcLayout = Layout, - Stride< _0, _1>>; - using DstLayout = Layout, - Stride<_4096, _1>>; - using RefLayout = SrcLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMEM_STORE Copy Traits -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : 
TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = 
typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename 
Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : 
TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = 
typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename 
Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : 
TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = 
typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename 
Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : 
TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = 
typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename 
Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template <> -struct Copy_Traits - : TMEM_STORE_Unpack -{ - using ThrID = typename Copy_Traits::ThrID; - using ValID = typename Copy_Traits::ValID; - using SrcLayout = typename Copy_Traits::DstLayout; - using DstLayout = typename Copy_Traits::SrcLayout; - using RefLayout = typename Copy_Traits::RefLayout; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////////////////////////////// // // UTCCP Copy Traits // //////////////////////////////////////////////////////////////////////////////////////////////////// +namespace SM100::TMEM::UTCCP { + +// +// Specialized copy_unpack implementation for SM100::TMEM::UTCCP instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const&, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + CopyOp::copy(src[0], raw_pointer_cast(dst.data())); +} + +} // end namespace SM100::TMEM::UTCCP + // In the following UTCCP traits, the ValID is representing: // logical_bit_idx -> tmem_addr_offset. 
// And the logical_bit_idx is numbered in the order of: @@ -3344,132 +3512,77 @@ struct Copy_Traits // The last two modes provide boradcast transformation for 4x32DP and 2x64DP. // With above, the strides of first two modes are neccessary to be TMEM::DP_b and 1. // And the stride of the third mode in the SrcLayout must be zero. + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_1cta; + template <> struct Copy_Traits { using ThrID = Layout<_1>; - // logical bit_idx -> tmem_addr using ValID = Layout, Stride>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp256bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_2cta; + template <> struct Copy_Traits { using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - 
static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp256bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_1cta; + template <> struct Copy_Traits { using ThrID = Layout<_1>; - // logical bit_idx -> tmem_addr using ValID = Layout, Stride>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp128bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_2cta; + template <> struct Copy_Traits { using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_128dp128bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_1cta; + template <> struct Copy_Traits { @@ -3485,66 +3598,35 @@ struct Copy_Traits */ using ThrID = Layout<_1>; - // logical bit_idx -> tmem_addr using ValID = Layout, Stride>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4dp256bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_2cta; + template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_0,Stride<_32,_128>>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4dp256bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_1cta; + template <> struct Copy_Traits { @@ -3556,64 +3638,33 @@ struct Copy_Traits // [core_matrix_strided, core_matrix_leading, broadcast] using ValID = Layout, Stride<_DP,_1, _DPx32>>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32, _0>>>; - - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4x32dp128bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_2cta; + template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32, _0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_4x32dp128bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_1cta; + template <> struct Copy_Traits { @@ 
-3625,63 +3676,34 @@ struct Copy_Traits // [core_matrix_strided, core_matrix_leading, broadcast] using ValID = Layout, Stride<_DP,_1, _DPx64>>; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _64, _0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0213_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_2cta; + template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _64, _0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0213_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_1cta; + template <> struct Copy_Traits { @@ -3695,62 +3717,31 @@ struct Copy_Traits using ValID = Layout, Stride<_DP,_1 
,_DPx64,_DPx32>>; - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32,_4096,_0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0123_1cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_2cta; + template <> struct Copy_Traits { - using ThrID = Layout<_2>; - // logical bit_idx -> tmem_addr using ValID = typename Copy_Traits::ValID; - - // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, Stride<_0,Stride<_1, _32, _4096,_0>>>; - // Map from (dst-thr,dst-val) to bit using DstLayout = Layout, Stride<_0,_1>>; - // Reference map from (thr,val) to bit using RefLayout = DstLayout; - - - template - CUTE_HOST_DEVICE friend constexpr - void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); - static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); - SM100_UTCCP_2x64dp128bitlw0123_2cta::copy(src[0], raw_pointer_cast(dst.data())); - } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// + template CUTE_HOST_DEVICE constexpr @@ -3775,4 +3766,3 @@ make_utccp_copy(CopyOp const&, } // namespace cute -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp 
b/include/cute/atom/copy_traits_sm90_im2col.hpp index beefa63f..e4d1e3ff 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -647,7 +647,7 @@ make_tma_atom_im2col(CopyOp, gtensor_cwhdn, range_c, range_whdn, - detail::get_swizzle_portion(slayout), + get_swizzle_portion(slayout), tma_layout_vt, lower_corner_whd, upper_corner_whd, diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index f336eff2..820dc103 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -37,10 +37,13 @@ #include #include #include -#include // cute::TMEM:: +#include // cute::TMEM:: + #include #include // cute::GMMA:: #include // cute::GMMA:: +#include // UTCCP smem desc + #include // Check that aggregate initialization in .with() initializes all fields @@ -417,6 +420,9 @@ constexpr auto get_utccp_smem_desc_tensor(Tensor const& smem_u namespace UMMA { +// Import TMEM constants +namespace TMEM = cute::TMEM; + enum class TmemAllocMode { // Default allocation mode. // If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be @@ -3053,7 +3059,7 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types"); static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); - + using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_2sm; diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index 187b7e41..e3dd6d27 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -51,8 +51,8 @@ // but do _not_ include references like int& or float&. 
// (See std::tie for an example of a tuple of references.) // -// Standard-layout types preserve ABI across host-device boundaries. -// They are safe to use as device kernel parameters. +// Standard-layout types preserve ABI across host-device boundaries. They are safe to use as device kernel parameters. +// The standard-layout requirement prevents a more common EBO-based implemented of cute::tuple. // // The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of // the conversion SFINAE, special overloading, and avoiding cvref template types. @@ -62,12 +62,15 @@ namespace cute { -namespace detail +template +struct tuple; + +namespace eso { // ESO stands for "empty structure optimization." -// We use this technique to ensure that cute::tuple -// doesn't waste space storing template arguments that have no data (like integral_constant). +// We use this technique to ensure that cute::tuple doesn't waste space +// storing template arguments that have no data (like integral_constant). // Empty types in the template argument list are not even constructed, // and do not have unique element addresses. Calling `get` // constructs and returns an instance of an empty type on demand. 
@@ -131,94 +134,92 @@ struct ESO { }; // Get Nth value from ESO -template +template CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t>> -getv(ESO const&) -{ - return {}; -} - -template -CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t> const&> -getv(ESO const& s) +R +getr(S&& s) noexcept { if constexpr (N == 0) { - return static_cast(s.first_); + return static_cast(s).first_; } else { - return getv(s.rest_); + return getr(static_cast(s).rest_); } - CUTE_GCC_UNREACHABLE; } -template +// Compilers disagree on decltype(auto), so these implementations avoid it at cost +template CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t> &> -getv(ESO& s) +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> const&> +getv_cr(ESO const& s) noexcept { - if constexpr (N == 0) { - return static_cast(s.first_); + if constexpr (cute::is_empty>>::value) { + return {}; } else { - return getv(s.rest_); + return getr> const&, N>(s); } - CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr -cute::enable_if_t>>::value, - cute::tuple_element_t> &&> -getv(ESO&& s) +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &> +getv_r(ESO& s) noexcept { - if constexpr (N == 0) { - return static_cast(s.first_); + if constexpr (cute::is_empty>>::value) { + return {}; } else { - return getv(static_cast&&>(s.rest_)); + return getr> &, N>(s); } - CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr -auto -findt(ESO const& t) noexcept +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &&> +getv_rr(ESO&& s) noexcept { - if constexpr (cute::is_same_v) { - return C{}; - } else - if constexpr (sizeof...(Rest) == 0) { - return C{}; - } else - if constexpr (IsRestEmpty) { - return cute::detail::findt(ESO_t{}); + if constexpr (cute::is_empty>>::value) { + return {}; } else { - return 
cute::detail::findt(t.rest_); + return getr> &&, N>(static_cast&&>(s)); } + CUTE_GCC_UNREACHABLE; } -} // end namespace detail +} // end namespace eso template -struct tuple : detail::ESO_t +struct tuple : eso::ESO_t { CUTE_HOST_DEVICE constexpr tuple() {} CUTE_HOST_DEVICE constexpr - tuple(T const&... t) : detail::ESO_t(t...) {} + tuple(T const&... t) : eso::ESO_t(t...) {} }; template <> struct tuple<> {}; +// +// make_tuple (value-based implementation) +// + +template +CUTE_HOST_DEVICE constexpr +tuple +make_tuple(T const&... t) +{ + return {t...}; +} + // Returns the element in the ith position of the tuple template CUTE_HOST_DEVICE constexpr @@ -226,7 +227,7 @@ decltype(auto) get(tuple const& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(t); + return eso::getv_cr(t); } template @@ -235,7 +236,7 @@ decltype(auto) get(tuple& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(t); + return eso::getv_r(t); } template @@ -244,22 +245,22 @@ decltype(auto) get(tuple&& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(static_cast&&>(t)); + return eso::getv_rr(static_cast&&>(t)); } -// Returns the position of type X (as a static integer) in the tuple -// type's argument list. X must be unique in the argument list. +// Returns the first position of type X (as a static integer) in the tuple +// type's argument list. 
template CUTE_HOST_DEVICE constexpr auto -find(tuple const& t) noexcept +find(tuple const&) noexcept { - return detail::findt(t); + return cute::C...>>{}; } // // Custom is_tuple trait simply checks the existence of tuple_size -// and assumes std::get(.), std::tuple_element +// and assumes get(.), tuple_element // namespace detail { @@ -273,19 +274,7 @@ template struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; template -constexpr bool is_tuple_v = cute::is_tuple::value; - -// -// make_tuple (value-based implementation) -// - -template -CUTE_HOST_DEVICE constexpr -tuple -make_tuple(T const&... t) -{ - return {t...}; -} +static constexpr bool is_tuple_v = cute::is_tuple::value; // // tuple_cat concatenates multiple cute::tuple into a single cute::tuple, diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index b8ac5f0d..dfffbe25 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -31,6 +31,7 @@ #pragma once #include // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE +#include namespace cute { @@ -39,11 +40,35 @@ template struct type_list {}; // get for type_list -// requires tuple_element_t> to have std::is_default_constructible +// Get an instance of the Ith type in the pack T... +// Requires tuple_element_t> to have std::is_default_constructible template CUTE_HOST_DEVICE constexpr CUTE_STL_NAMESPACE::tuple_element_t> -get(type_list const& t) noexcept { +get(type_list const&) noexcept { + return {}; +} + +// Find the index of the first true in the pack B... +template +struct find_true { + CUTE_HOST_DEVICE static constexpr size_t find() { + size_t i = 0; + (void) ((B ? true : (++i, false)) || ...); + return i; + } + static constexpr size_t value = find(); +}; + +template +static constexpr size_t find_true_v = find_true::value; + +// find for type_list +// Finds the first position of type X (as a static integer) in the T... 
pack +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::integral_constant...>> +find(type_list const&) noexcept { return {}; } @@ -69,9 +94,8 @@ struct tuple_size> template struct tuple_element> -{ - using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; -}; + : CUTE_STL_NAMESPACE::tuple_element> +{}; } // end namespace std @@ -94,9 +118,8 @@ struct tuple_size> template struct tuple_element> -{ - using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; -}; + : CUTE_STL_NAMESPACE::tuple_element> +{}; } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 97eafa7a..3f02a41d 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -834,7 +834,7 @@ coalesce_x(Layout const& layout) } else { return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); } - + CUTE_GCC_UNREACHABLE; } @@ -1030,7 +1030,7 @@ template CUTE_HOST_DEVICE constexpr auto -composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, +composition_impl(LShape const& lhs_shape, [[maybe_unused]] LStride const& lhs_stride, RShape const& rhs_shape, RStride const& rhs_stride) { if constexpr (is_tuple::value) { // Right-distributivity of Layout composition for RHS tuple @@ -1067,7 +1067,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, auto rest_stride = get<3>(init); auto curr_shape = get(lhs_shape); - auto curr_stride = get(lhs_stride); + [[maybe_unused]] auto curr_stride = get(lhs_stride); // Strong divisibility condition -- requires composition to be statically verifiable. 
//CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 43d3c4b2..ef1ca18e 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -128,8 +128,6 @@ make_fragment_like(ComposedLayout,Offset,Layout> const& layout) // Utilities // -namespace detail { - // Get just the Swizzle part of a composed layout. template CUTE_HOST_DEVICE constexpr @@ -167,8 +165,6 @@ get_nonswizzle_portion(Layout const& slayout) return slayout; } -} // namespace detail - // // Slice a Swizzled ComposedLayout // diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index c634e884..c9c636a0 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -42,7 +42,7 @@ namespace cutlass { namespace arch { constexpr int sm100_smem_capacity_bytes = 232448; -constexpr int sm120_smem_capacity_bytes = 102400; +constexpr int sm120_smem_capacity_bytes = 101376; #if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index d7036baf..3d5ec10b 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -50,6 +50,9 @@ #define CUTLASS_ARCH_TCGEN_ENABLED 1 #endif +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +#define CUTLASS_ARCH_TCGEN_ENABLED 1 +#endif namespace cutlass { /// @brief diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h index 1dd27f78..e5daf829 100644 --- a/include/cutlass/arch/config.h +++ b/include/cutlass/arch/config.h @@ -92,6 +92,14 @@ #define CUTLASS_ARCH_MMA_SM100A_ENABLED 1 #endif + // SM100f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM100F_SUPPORTED 1 + #endif + + #if 
(!defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) && CUDA_ARCH_FAMILY(1000)) + #define CUTLASS_ARCH_MMA_SM100F_ENABLED CUTLASS_ARCH_MMA_SM100F_SUPPORTED + #endif #endif #endif @@ -109,6 +117,14 @@ #define CUTLASS_ARCH_MMA_SM101A_ENABLED 1 #endif + // SM101f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM101F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) && CUDA_ARCH_FAMILY(1010)) + #define CUTLASS_ARCH_MMA_SM101F_ENABLED CUTLASS_ARCH_MMA_SM101F_SUPPORTED + #endif #endif #endif @@ -124,12 +140,21 @@ #define CUTLASS_ARCH_MMA_SM120A_ENABLED 1 #endif + // SM120f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM120F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) && CUDA_ARCH_FAMILY(1200)) + #define CUTLASS_ARCH_MMA_SM120F_ENABLED CUTLASS_ARCH_MMA_SM120F_SUPPORTED + #endif #endif #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) # define CUTLASS_ARCH_CLC_ENABLED #endif diff --git a/include/cutlass/arch/grid_dependency_control.h b/include/cutlass/arch/grid_dependency_control.h index ae66de27..e7defb5d 100644 --- a/include/cutlass/arch/grid_dependency_control.h +++ b/include/cutlass/arch/grid_dependency_control.h @@ -53,6 +53,20 @@ #endif #endif +#ifndef CUTLASS_GDC_ENABLED + #if(CUDA_BARRIER_ENABLED && \ + defined(CUTLASS_ENABLE_GDC_FOR_SM100) && \ + defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 1000 &&\ + (defined(__CUDA_ARCH_FEAT_SM100_ALL) || CUDA_ARCH_FAMILY(1000))) || \ + (__CUDA_ARCH__ == 1010 &&\ + 
(defined(__CUDA_ARCH_FEAT_SM101_ALL) || CUDA_ARCH_FAMILY(1010))) || \ + (__CUDA_ARCH__ == 1200 &&\ + (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))))) + #define CUTLASS_GDC_ENABLED + #endif +#endif + namespace cutlass { namespace arch { @@ -84,6 +98,5 @@ static constexpr bool IsGdcGloballyEnabled = true; static constexpr bool IsGdcGloballyEnabled = false; #endif - } // namespace arch } // namespace cutlass diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h index 557643e5..a65ee328 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -47,6 +47,14 @@ #define CUDA_CTA_RECONFIG_ACTIVATED 1 #endif + #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ + (__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \ + || (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \ + || (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \ + ) + #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif + #endif namespace cutlass { diff --git a/include/cutlass/conv/collective/builders/sm100_umma_builder.inl b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl index db1f7dae..9a9d4cb4 100644 --- a/include/cutlass/conv/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl @@ -168,7 +168,7 @@ private: // Calculate SMEM matrix A and B buffers' pipeline stages static constexpr uint32_t AccumulatorPipelineStageCount = 2; - static constexpr uint32_t SchedulerPipelineStageCount = 2; + static constexpr uint32_t SchedulerPipelineStageCount = 1; static constexpr uint32_t CLCResponseSize = 16; // AccumulatorPipeline = PipelineUmmaAsync @@ -179,8 +179,6 @@ private: static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); // CLC (scheduler) response static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * CLCResponseSize; - // CLC Throttle pipeline storage - static 
constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); // Tmem dealloc static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); // Tmem ptr storage @@ -190,7 +188,6 @@ private: CLCPipelineStorage + LoadOrderBarrierStorage + TmemDeallocStorage + - CLCThrottlePipelineStorage + CLCResponseStorage + TmemBasePtrsStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. @@ -204,7 +201,12 @@ private: constexpr static int NumSpatialDimensions = detail::gmem_layout_tags_to_spatial_dims(); using DispatchPolicy = cutlass::conv::MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK>; + ConvOp, + PipelineStages, + NumSpatialDimensions, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>; public: using CollectiveOp = cutlass::conv::collective::CollectiveConv< diff --git a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp index dc75b988..278f69f9 100644 --- a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -28,9 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ -// -// #pragma once @@ -66,6 +64,8 @@ template < conv::Operator ConvOp, int Stages, int NumSpatialDims, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL) class ElementA_, @@ -75,7 +75,12 @@ template < class TileTraitsB_> struct CollectiveConv< MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, Stages, NumSpatialDims, ClusterShape>, + ConvOp, + Stages, + NumSpatialDims, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, TileShapeMNKL_, ElementA_, ElementB_, @@ -87,7 +92,12 @@ struct CollectiveConv< // Type Aliases // using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< - ConvOp, Stages, NumSpatialDims, ClusterShape>; + ConvOp, + Stages, + NumSpatialDims, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK) using ElementA = ElementA_; using ElementB = ElementB_; @@ -348,10 +358,12 @@ public: // Constructor // CUTLASS_DEVICE - CollectiveConv(Params const& params) { + CollectiveConv(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { if constexpr (IsDynamicCluster) { - dim3 cs = cute::cluster_shape(); - const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? 
¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; } @@ -648,28 +660,14 @@ public: } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE static void - prefetch_tma_descriptors(Params const& mainloop_params) { - if constexpr (IsDynamicCluster) { - dim3 cs = cute::cluster_shape(); - const bool is_fallback_cluster = (cs.x == mainloop_params.cluster_shape_fallback.x && cs.y == mainloop_params.cluster_shape_fallback.y); - if (is_fallback_cluster) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a_fallback.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b_fallback.get_tma_descriptor()); - } - else { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } - } - else { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); } /// Construct A Single Stage's Accumulator Shape - CUTLASS_DEVICE auto + CUTLASS_DEVICE static auto partition_accumulator_shape() { auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) @@ -794,11 +792,10 @@ public: Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); - Layout cta_layout_mnk = make_layout(cluster_shape); + // Define 
the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); - int block_rank_in_cluster = cute::block_rank_in_cluster(); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); // Project the cta_layout for tma_a along the n-modes auto [tAgA_mk, tAsA] = tma_partition(*observed_tma_load_a_, @@ -890,7 +887,7 @@ public: } CUTLASS_DEVICE auto - mma_init(TensorStorage& shared_tensors) { + mma_init(TensorStorage& shared_tensors) const { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -909,6 +906,9 @@ private: typename Params::TMA_A const* observed_tma_load_a_ = nullptr; typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp index b4bf8a53..d569cb1c 100644 --- a/include/cutlass/conv/dispatch_policy.hpp +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -86,7 +86,10 @@ struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { // SM100 tensor op kernel schedule -struct KernelImplicitTmaWarpSpecializedSm100 { }; +struct KernelImplicitTmaWarpSpecializedSm100 { + static constexpr int SchedulerPipelineStageCount = 0; + static constexpr int AccumulatorPipelineStageCount = 0; +}; // Pseudo-policies for builder auto override that dispatches to the KernelImplicitTmaWarpSpecializedSm100 // but for opting into 1 or 2 SM atoms @@ -96,11 +99,23 @@ struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializ 
struct KernelStridedDgradTmaWs1SmSm100 { }; struct KernelStridedDgradTmaWs2SmSm100 { }; +// Policy for implicit gemm kernel +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelScheduleImplicitTmaWarpSpecializedSm100 : KernelImplicitTmaWarpSpecializedSm100 { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop template< conv::Operator ConvOp_, int Stages_, int NumSpatialDimensions_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>> > struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { @@ -109,7 +124,7 @@ struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { static constexpr Operator ConvOp = ConvOp_; using ClusterShape = ClusterShape_; using ArchTag = arch::Sm100; - using Schedule = KernelImplicitTmaWarpSpecializedSm100; + using Schedule = KernelScheduleImplicitTmaWarpSpecializedSm100; static_assert(NumSpatialDimensions >= 1); }; diff --git a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp index 90236e1f..0874d8f8 100644 --- a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -29,8 +29,6 @@ * **************************************************************************************************/ - - #pragma once #include "cutlass/cutlass.h" @@ -110,7 +108,8 @@ public: static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; // TileID scheduler // CLC pipeline depth determines how many waves (stages-1) the scheduler can race ahead - static constexpr uint32_t SchedulerPipelineStageCount = 2; + static constexpr uint32_t 
SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; using TileSchedulerTag = TileSchedulerTag_; using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< @@ -135,7 +134,6 @@ public: static constexpr uint32_t NumFixupBarriers = 1; // Pipelines and pipeline states - static constexpr uint32_t AccumulatorPipelineStageCount = SchedulerPipelineStageCount; static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); // Pipeline and pipeline state types @@ -157,10 +155,6 @@ public: using CLCPipelineState = cutlass::PipelineDetail::PipelineCLCFetchAsyncPipelineState; using CLCPipelineSharedStorage = cutlass::PipelineDetail::PipelineCLCFetchAsyncSharedStorage; - using CLCThrottlePipeline = cutlass::PipelineAsync; - using CLCThrottlePipelineState = cutlass::PipelineDetail::PipelineAsyncPipelineState; - using CLCThrottlePipelineSharedStorage = cutlass::PipelineDetail::PipelineAsyncSharedStorage; - using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; @@ -172,14 +166,12 @@ public: using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; using CLCPipelineStorage = CLCPipelineSharedStorage; using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; - using CLCThrottlePipelineStorage = CLCThrottlePipelineSharedStorage; alignas(16) MainloopPipelineStorage mainloop; alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) LoadOrderBarrierStorage load_order; alignas(16) CLCPipelineStorage clc; alignas(16) AccumulatorPipelineStorage accumulator; - alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; } pipelines; @@ -193,7 +185,6 @@ public: EpilogueTensorStorage epilogue; MainloopTensorStorage mainloop; } tensors; - }; 
static constexpr int SharedStorageSize = sizeof(SharedStorage); @@ -207,7 +198,7 @@ public: KernelHardwareInfo hw_info{}; TileSchedulerArguments scheduler{}; }; - + // Kernel device entry point API struct Params { using ProblemShapeMNKL = decltype(CollectiveMainloop::get_problem_shape_MNKL(ProblemShape{})); @@ -398,7 +389,7 @@ public: : WarpCategory::Epilogue; uint32_t lane_predicate = cute::elect_one_sync(); - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); int cluster_size = size(cluster_shape); uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; @@ -407,24 +398,23 @@ public: constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_category == WarpCategory::Sched) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - } - if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + collective_mainloop.prefetch_tma_descriptors(); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + 
collective_epilogue.prefetch_tma_descriptors(params.epilogue); + } + // Do we load source tensor C or other aux inputs bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); - IsParticipant is_participant = { (warp_category == WarpCategory::MMA), // mma (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched @@ -462,7 +452,7 @@ public: epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - epi_load_pipeline_params.initializing_warp = 4; + epi_load_pipeline_params.initializing_warp = 1; EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -474,7 +464,7 @@ public: typename LoadOrderBarrier::Params load_order_barrier_params; load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; load_order_barrier_params.group_size = NumMainloopLoadThreads; - load_order_barrier_params.initializing_warp = 5; + load_order_barrier_params.initializing_warp = 3; LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); // CLC pipeline @@ -493,7 +483,7 @@ public: clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; } clc_pipeline_params.transaction_bytes = CLCResponseSize; - clc_pipeline_params.initializing_warp = 1; + clc_pipeline_params.initializing_warp = 4; CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); // Mainloop-Epilogue pipeline @@ -507,29 +497,13 @@ public: // Only one producer thread arrives on this barrier. 
accumulator_pipeline_params.producer_arv_count = 1; accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; - accumulator_pipeline_params.initializing_warp = 2; + accumulator_pipeline_params.initializing_warp = 5; AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape, cute::true_type{}, // Perform barrier init cute::false_type{}); // Delay mask calculation - // CLC throttle pipeline - typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; - if (WarpCategory::MainloopLoad == warp_category) { - clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; - } - if (WarpCategory::Sched == warp_category) { - clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; - } - clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - clc_throttle_pipeline_params.dst_blockid = 0; - clc_throttle_pipeline_params.initializing_warp = 3; - CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); - CLCThrottlePipelineState clc_pipe_throttle_consumer_state; - CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); - // Tmem allocator TmemAllocator tmem_allocator{}; @@ -544,12 +518,10 @@ public: // We need this to guarantee that the Pipeline init is visible // To all producers and consumer threadblocks in the cluster - if (cluster_size > 1) { - cute::cluster_arrive_relaxed(); - } - else { - __syncthreads(); - } + pipeline_init_arrive_relaxed(cluster_size); + + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); uint32_t tmem_stage_ptrs[AccumulatorPipelineStageCount]; MainloopPipelineState mainloop_pipe_consumer_state; @@ -571,7 +543,7 @@ public: // Calculate mask after 
cluster barrier arrival mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); - accumulator_pipeline.init_masks(cluster_shape); + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); // TileID scheduler TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, problem_shape_MNKL, TileShape{}, block_id_in_cluster); @@ -583,58 +555,13 @@ public: int TmemColumnsPerAccumulatorTile = cutlass::detail::find_tmem_tensor_col_offset(accumulators); pipeline_init_wait(cluster_size); - if (is_participant.sched) { - - // Whether a new CLC query must be performed. - // See comment below where this variable is updated for a description of - // why this variable is needed. - bool requires_clc_query = true; - - do { - if (requires_clc_query) { - // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. - clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); - clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); - ++clc_pipe_throttle_consumer_state; - - // Query next clcID and update producer state - clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); - } - - // Fetch next work tile - auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( - work_tile_info, - clc_pipeline, - clc_pipe_consumer_state - ); - - // Only perform a new CLC query if we consumed a new CLC query result in - // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does - // not consume a new CLC query response is when processing stream-K units. - // The current stream-K scheduler uses single WorkTileInfo to track multiple - // (potentially-partial) tiles to be computed via stream-K. In this case, - // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, - // rather than consuming a CLC query response. 
- requires_clc_query = increment_pipe; - if (increment_pipe) { - ++clc_pipe_consumer_state; - } - - work_tile_info = next_work_tile_info; - } while (work_tile_info.is_valid()); - clc_pipeline.producer_tail(clc_pipe_producer_state); - } - else if (is_participant.main_load) { - + if (is_participant.main_load) { // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction cutlass::arch::wait_on_dependent_grids(); bool do_load_order_arrive = is_epi_load_needed; - auto load_inputs = collective_mainloop.load_init( - problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); Tensor gA_mk = get<0>(load_inputs); - bool requires_clc_query = true; do { // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. @@ -642,12 +569,6 @@ public: auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); - if (is_first_cta_in_cluster && requires_clc_query) { - clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); - clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); - ++clc_pipe_throttle_producer_state; - } - auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( params.mainloop, mainloop_pipeline, @@ -683,7 +604,6 @@ public: ); work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); - requires_clc_query = increment_pipe; if (increment_pipe) { ++clc_pipe_consumer_state; } @@ -691,60 +611,43 @@ public: collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); } - else if (is_participant.epi_load) { - // Ensure that the prefetched kernel does not touch - // unflushed global memory prior to this instruction - cutlass::arch::wait_on_dependent_grids(); + else if (is_participant.sched) { + // Whether a new CLC query must be performed. 
+ // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; - bool do_load_order_wait = true; - bool do_tail_load = false; do { - bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } - // Get current work tile and fetch next work tile + // Fetch next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( work_tile_info, clc_pipeline, clc_pipe_consumer_state ); - work_tile_info = next_work_tile_info; + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. 
+ requires_clc_query = increment_pipe; if (increment_pipe) { ++clc_pipe_consumer_state; } - if (compute_epilogue) { - - if (do_load_order_wait) { - load_order_barrier.wait(); - do_load_order_wait = false; - } - - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - CtaShape_MNK{}, - cta_coord_mnkl, - TileShape{}, - TiledMma{}, - shared_storage.tensors.epilogue - ); - - do_tail_load = true; - } - - // Calculate the cta coordinates of the next work tile - cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + work_tile_info = next_work_tile_info; } while (work_tile_info.is_valid()); - - if (do_tail_load) { - collective_epilogue.load_tail( - epi_load_pipeline, epi_load_pipe_producer_state, - epi_store_pipeline, epi_store_pipe_producer_state); - } + clc_pipeline.producer_tail(clc_pipe_producer_state); } + else if (is_participant.mma) { // Tmem allocation sequence tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); @@ -757,6 +660,7 @@ public: tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; } auto mma_inputs = collective_mainloop.mma_init(shared_storage.tensors.mainloop); + do { auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); @@ -788,7 +692,6 @@ public: mma_inputs, k_tile_count ); - accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); } ++accumulator_pipe_producer_state; @@ -802,6 +705,7 @@ public: // Release the right to allocate before deallocations so that the next CTA can rasterize tmem_allocator.release_allocation_lock(); + // Leader MMA waits for leader + peer epilogues to release accumulator stage if (is_mma_leader_cta) { accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); @@ -816,8 +720,66 @@ public: // Free entire tmem allocation tmem_allocator.free(tmem_base_ptr, 
TmemAllocator::Sm100TmemCapacityColumns); - } + + else if (is_participant.epi_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + else if (is_participant.epilogue) { // Wait for tmem allocate here tmem_allocation_result_barrier.arrive_and_wait(); @@ -875,13 +837,16 @@ public: epi_load_pipe_consumer_state = load_state_next; epi_store_pipe_producer_state = store_state_next; accumulator_pipe_consumer_state = acc_state_next; - do_tail_store = true; } work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); } while (work_tile_info.is_valid()); + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). if (do_tail_store) { collective_epilogue.store_tail( epi_load_pipeline, epi_load_pipe_consumer_state, @@ -889,19 +854,8 @@ public: CtaShape_MNK{}); } } - } -private: - - // Synchronization call. Blocks until barriers are initialized in shared memory. 
- CUTLASS_DEVICE - void - pipeline_init_wait(int cluster_size) { - if (cluster_size > 1) { - cute::cluster_wait(); - } else { - __syncthreads(); } } }; diff --git a/include/cutlass/detail/sm100_blockwise_scale_layout.hpp b/include/cutlass/detail/blockwise_scale_layout.hpp similarity index 67% rename from include/cutlass/detail/sm100_blockwise_scale_layout.hpp rename to include/cutlass/detail/blockwise_scale_layout.hpp index 8f75bd25..2d545bbd 100644 --- a/include/cutlass/detail/sm100_blockwise_scale_layout.hpp +++ b/include/cutlass/detail/blockwise_scale_layout.hpp @@ -179,11 +179,110 @@ struct Sm100BlockwiseScaleConfig { }; +template +struct RuntimeBlockwiseScaleConfig { + + using ShapeSFA = Shape, Shape, int32_t>; + using ShapeSFB = Shape, Shape, int32_t>; + + using StrideSFA = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using StrideSFB = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutSFA = Layout; + using LayoutSFB = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSFA{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSFB{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. 
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, sfm))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + auto mk_layout = make_layout( + make_shape(make_shape(sfm, cute::ceil_div(M, sfm)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. 
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + + if constexpr (majorSFB == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, sfn))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + auto nk_layout = make_layout( + make_shape(make_shape(sfn, cute::ceil_div(N, sfn)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout)))); + } + +}; + +// Sm90 only supports MN major for SFA and SFB for now +template +using Sm90BlockwiseScaleConfig = Sm100BlockwiseScaleConfig; + template constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { return Sm100BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; } +template +constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm90BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::detail diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index 44a83b9d..cf9b803b 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -208,6 +208,35 @@ namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// +// __CUDA_ARCH_SPECIFIC__ is introduced in 
CUDA 12.9 +#if !defined(CUDA_ARCH_CONDITIONAL) + +#if defined(__CUDA_ARCH_SPECIFIC__) +#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (__CUDA_ARCH_SPECIFIC__ == ARCH_XXYY) +#else +#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (false) +#endif + +#endif + +// __CUDA_ARCH_FAMILY_SPECIFIC__ is introduced in CUDA 12.9 +#if !defined(CUDA_ARCH_FAMILY) + +#if defined(__CUDA_ARCH_FAMILY_SPECIFIC__) +#define CUDA_ARCH_FAMILY(ARCH_XXYY) (__CUDA_ARCH_FAMILY_SPECIFIC__ == ARCH_XXYY) +#else +#define CUDA_ARCH_FAMILY(ARCH_XXYY) (false) +#endif + +#endif + +#if !defined(CUDA_ARCH_CONDITIONAL_OR_FAMILY) +#define CUDA_ARCH_CONDITIONAL_OR_FAMILY(ARCH_XXYY) \ + (CUDA_ARCH_CONDITIONAL(ARCH_XXYY) || CUDA_ARCH_FAMILY(ARCH_XXYY)) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + }; // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index a0a183b0..562adc65 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -33,10 +33,10 @@ #include "cute/layout.hpp" #include "cute/pointer_sparse.hpp" // cute::is_sparse #include "cute/swizzle.hpp" // cute::Swizzle -#include "cute/swizzle_layout.hpp" // cute::detail::get_swizzle_portion +#include "cute/swizzle_layout.hpp" // cute::get_swizzle_portion #include "cute/util/type_traits.hpp" #include "cute/arch/copy_sm90_tma.hpp" -#include "cute/arch/copy_sm100_tma.hpp" +#include "cute/arch/copy_sm100_tma.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/layout/tensor.h" @@ -219,8 +219,8 @@ stride_to_layout_tag_A() { return layout::ColumnMajor{}; } // Specialize for sparse layout - else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && - cute::rank(cute::get<1>(InternalStrideA{})) == 2 && + else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && + 
cute::rank(cute::get<1>(InternalStrideA{})) == 2 && cute::is_same_v(InternalStrideA{}))>>) { return layout::ColumnMajor{}; } @@ -308,8 +308,8 @@ constexpr bool is_tma_copy_engine() { || cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v ) { return true; } @@ -349,7 +349,7 @@ get_alignment_count_from_gmem_tiled_copy() { cutlass::gemm::collective::detail::is_sm10x_f8f6f4_element() && cute::is_same_v::type, uint8_t>) { return 128; } - + // For sparse MMA, alignment in logical elements is increased by sparsity factor if constexpr (cute::is_sparse_v) { return 128 / sizeof_bits::value * ElementMma::sparsity; @@ -366,7 +366,7 @@ get_alignment_count_from_gmem_tiled_copy() { // Return alignment bit requirements for the GEMM inputs. template < class ElementType - , bool IsF8F6F4SubBytes=false + , bool IsF8F6F4SubBytes=false > constexpr int get_input_alignment_bits() { @@ -383,12 +383,12 @@ get_input_alignment_bits() { template constexpr int get_output_alignment_bits() { - + if constexpr (sizeof_bits::value == 6) { // U6 format : The inner tensor size dimension must be a multiple of 96B. 
return 96 * 8; } - + return 128; } @@ -424,7 +424,7 @@ template CUTLASS_HOST_DEVICE constexpr size_t alignment_for_swizzle(Layout layout) { - return alignment_for_swizzle(cute::detail::get_swizzle_portion(layout)); + return alignment_for_swizzle(cute::get_swizzle_portion(layout)); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/builders/sm120_builder.inl b/include/cutlass/epilogue/collective/builders/sm120_builder.inl index ad1f44a0..e1c1bff8 100644 --- a/include/cutlass/epilogue/collective/builders/sm120_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm120_builder.inl @@ -63,13 +63,27 @@ struct EpilogueSFVecSize> static constexpr int value = FusionOp::SFVecSize; }; +// Helper to deduce NumEpilogueWarpGroups based on Schedule +template +struct GetNumEpilogueWarpGroups { + static constexpr int value = 2; +}; + +template +struct GetNumEpilogueWarpGroups> { + static constexpr int value = Schedule::NumEpilogueWarpGroups; +}; + // Returns the parameterized dispatch policy for the TMA epilogue -template +template constexpr auto sm120_get_tma_dispatch_policy() { using namespace cute; constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + using StrideD = cutlass::detail::TagToStrideC_t; + using InternalStrideD = cute::remove_pointer_t; + constexpr bool IsGroupedGemmKernel = !cute::is_same_v; // For 120, a FragmentSize of 4 is used to match the // output per thread from each MMA. Epilogue subtiles iterate over multiple of these @@ -86,9 +100,17 @@ sm120_get_tma_dispatch_policy() { // SM120 epilogues use smaller stage counts in order to fit within the limited shared memory capacity. constexpr int StagesC = ReuseSmem ? 
cute::max(cute::min(EpiTiles, 2), StagesD+1) - : StagesD; - - return Sm120TmaWarpSpecialized{}; + : StagesD; + + constexpr int NumEpilogueWarpGroups = GetNumEpilogueWarpGroups::value; + + if constexpr (IsGroupedGemmKernel) { + return Sm120PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm120TmaWarpSpecialized{}; + } } // Returns the smem layout atom to be used for C or D matrix @@ -291,6 +313,9 @@ struct Sm120TmaBuilderImpl { using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; + using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; + using CopyOpS2G = cute::conditional_t, SM90_TMA_STORE_IM2COL, @@ -306,15 +331,15 @@ struct Sm120TmaBuilderImpl { // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; - using SmemLayoutAtomC = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); - using SmemLayoutAtomD = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); + using SmemLayoutAtomC = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); + using SmemLayoutAtomD = decltype(detail::sm120_get_epilogue_smem_swizzle_layout_atom()); - using CopyOpS2R = decltype(detail::sm120_get_smem_load_op_for_source()); + using CopyOpS2R = decltype(detail::sm120_get_smem_load_op_for_source()); - using CopyOpR2S = decltype(detail::sm120_get_smem_store_op_for_accumulator()); + using CopyOpR2S = decltype(detail::sm120_get_smem_store_op_for_accumulator()); // Get register to register tiled copy that happen before shared memory store. - using CopyOpR2R = decltype(detail::sm120_get_register_transform_op()); + using CopyOpR2R = decltype(detail::sm120_get_register_transform_op()); // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. 
fusion::Sm90LinearCombination @@ -334,8 +359,32 @@ struct Sm120TmaBuilderImpl { constexpr static bool ReuseSmemC = DispatchPolicy::ReuseSmemC; constexpr static bool DelayTmaStore = DispatchPolicy::DelayTmaStore; + //Helper to deduce BaseDispatchPolicy based on DispatchPolicy + template + struct GetBaseDispatchPolicy { + using Type = T; + }; + + template + struct GetBaseDispatchPolicy> { + using Type = typename cutlass::epilogue::Sm90PtrArrayTmaWarpSpecialized; + }; + + template + struct GetBaseDispatchPolicy> { + using Type = typename cutlass::epilogue::Sm90TmaWarpSpecialized; + }; + + using BaseDispatchPolicy = typename GetBaseDispatchPolicy::Type; + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< - Sm90TmaWarpSpecialized, + BaseDispatchPolicy, TileShape_MNK, EpilogueTile_MN, ElementC_, // Need to pass void through to expose via GemmUniversal @@ -394,13 +443,15 @@ struct CollectiveBuilder< cute::enable_if_t || cute::is_same_v || cute::is_same_v || + cute::is_same_v || + cute::is_same_v || cute::is_same_v >> { private: using EpilogueTile_MN = decltype(detail::sm120_compute_tile_shape_or_override, FusionOperation>()); using DispatchPolicy = - decltype(detail::sm120_get_tma_dispatch_policy, Schedule>()); + decltype(detail::sm120_get_tma_dispatch_policy()); public: diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index b7bd6f40..0d019b1c 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -35,6 +35,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/epilogue/collective/detail.hpp" @@ -225,22 +226,27 @@ public: return; } + using FragCType = remove_cvref_t; + using FragDType = remove_cvref_t; + // source is needed if (epilogue_op.is_source_needed()) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 
size(accumulators); ++i) { - if (elem_less(tCcD(i), residue_tCcD)) { - tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); - } + FragCType fragC; + bool pred = elem_less(tCcD(i), residue_tCcD); + arch::global_load(fragC, &tCgC(i), pred); + FragDType fragD = epilogue_op(accumulators(i), fragC); + arch::global_store(fragD, &tCgD(i), pred); } } // source is not needed, avoid load else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), residue_tCcD)) { - tCgD(i) = epilogue_op(accumulators(i)); - } + bool pred = elem_less(tCcD(i), residue_tCcD); + FragDType fragD = epilogue_op(accumulators(i)); + arch::global_store(fragD, &tCgD(i), pred); } } } diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 2759d0c6..2c72c301 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -124,6 +124,23 @@ struct sm90_is_ptr_array_tma_dispatch_policy< NumEpilogueWarpGroups>> : cute::true_type {}; +template< + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups +> +struct sm90_is_ptr_array_tma_dispatch_policy< + Sm120PtrArrayTmaWarpSpecialized> + : cute::true_type {}; + template static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy::value; diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp index 77a3b510..0ed7d6b9 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -144,7 +144,6 @@ private: static_assert(StagesD >= 1, "StagesD must be >= 1"); constexpr static bool ReuseSmemC = ReuseSmemC_ && is_destination_supported; - constexpr static bool DelayTmaStore = 
DelayTmaStore_; constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); @@ -172,6 +171,12 @@ private: constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + // Do not unroll the epi subtile loop when the activation op is heavy, to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + // TMA store delay only benefits with loop unrolling + constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + struct CollectiveStorageWithC { alignas(SmemAlignmentC) ArrayEngine> smem_C; alignas(SmemAlignmentD) ArrayEngine> smem_D; @@ -860,10 +865,12 @@ public: synchronize(); } // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? 
NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; @@ -1215,10 +1222,12 @@ public: } // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; @@ -1478,16 +1487,23 @@ public: tensormaps_cp_fence_release( TensorMapStorage& shared_tensormap, cute::TmaDescriptor const* tensormap) { + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + // This operation only happens when the group/batch changes between consecutive tiles. + // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. 
+ auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + }; // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { if (is_source_supported) { - if (cute::elect_one_sync()) { - cute::tma_desc_commit_group(); - cute::tma_desc_wait_group(); - } + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); } } else if constexpr (is_destination_supported) { + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); } } diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp index 2eb5c582..c2172798 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -462,6 +462,10 @@ private: || is_same_v; // alloc reduction buffer for custom EVTs constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? size(EpilogueTile{}) : 0; + // Do not unroll the epi subtile loop when the activation op is heavy, to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + public: constexpr static int ThreadCount = 128; constexpr static uint32_t TmaTransactionBytes = 0; @@ -669,10 +673,12 @@ public: static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<4>(tTR_tAcc); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<3>(tTR_tAcc); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); + #pragma unroll(UnrollEpiLoop ? 
NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_last_iteration = iter_m == size<3>(tTR_tAcc)-1 && iter_n == size<4>(tTR_tAcc)-1; diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp index 89e5448c..3f445bf5 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp @@ -140,7 +140,6 @@ private: static_assert(StagesD >= 1, "StagesD must be >= 1"); constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_m_major_C = detail::is_m_major(); @@ -172,6 +171,12 @@ private: constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + // Do not unroll the epi subtile loop when the activation op is heavy, to reduce instruction size and register pressure. 
+ constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + // TMA store delay only benefits with loop unrolling + constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + struct CollectiveStorageWithC { alignas(SmemAlignmentC) ArrayEngine> smem_C; alignas(SmemAlignmentD) ArrayEngine> smem_D; @@ -808,10 +813,12 @@ public: ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; @@ -1162,10 +1169,12 @@ public: } // For each epilogue subtile within the CTA tile - CUTLASS_PRAGMA_UNROLL - for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { - CUTLASS_PRAGMA_UNROLL - for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? 
NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { int epi_m = iter_m, epi_n = iter_n; bool is_first_iteration = iter_m == 0 && iter_n == 0; bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index b27ec712..41c95f16 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -41,6 +41,7 @@ #include "cutlass/epilogue/thread/scale_type.h" #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp" #include "cutlass/detail/collective.hpp" #include "cutlass/detail/layout.hpp" #include "cutlass/trace.h" @@ -304,8 +305,9 @@ public: // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. // These will be replaced with correct values before the initial tma load. 
auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); - auto init_M = get<0>(init_shape); - auto init_N = get<1>(init_shape); + constexpr int tma_alignment_bits = 128; + auto init_M = tma_alignment_bits; + auto init_N = tma_alignment_bits; auto init_L = get<3>(init_shape); static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); @@ -761,7 +763,14 @@ public: CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple MMA tiles + CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); + } + else { CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + } + // Get TiledCopy for partition reference when consumer store. TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); // Get the fusion callbacks for the consumer store warps @@ -784,6 +793,12 @@ public: bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute_frg = recast>(tRS_rCompute); + // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; @@ -894,17 +909,41 @@ public: ++load_wait_state; } - int mma_m = epi_m; - int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; - Tensor 
tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple + // MMA tiles + static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; + static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; - // Vectorized fragment loop with visitor callback entry point - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD_frg); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { - tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { + int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; + + CUTLASS_PRAGMA_UNROLL + for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { + int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); + + tRS_rCompute_frg(idx_in_epi_subtile) = cst_callbacks.visit( + tRS_rAcc_frg_mn(0), idx_in_epi_subtile, epi_m, epi_n); + } + } } + else { + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + } + // The latest we can delay the TMA store is right before the smem store of the next iteration // since the current TMA store needs to be committed before we can acquire the next smem buffer if constexpr (DelayTmaStore) { @@ -918,7 
+957,7 @@ public: // Smem reduction callback entry point using current store buffer for workspace cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), - synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); // Copy tile from register to regiser if needed if constexpr (IsUseR2R) { @@ -930,6 +969,11 @@ public: copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } + // Copy tile from register to smem if constexpr (is_destination_supported) { copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); @@ -1140,7 +1184,6 @@ public: ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch, int32_t warp_group_idx) { - if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); @@ -1161,14 +1204,24 @@ public: TensorMapStorage& shared_tensormaps, cute::TmaDescriptor const* tensormap, const int32_t warp_group_idx = 0) { - + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + // This operation only happens when the group/batch changes between consecutive tiles. + // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. 
+ auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + }; // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { if constexpr (is_source_supported) { + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C); } } else if constexpr (is_destination_supported) { + tma_desc_wait_all_fn(); tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]); } } diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 870be4c2..2e6213fe 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -255,6 +255,23 @@ struct Sm120TmaWarpSpecialized { constexpr static bool DelayTmaStore = DelayTmaStore_; }; +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + int NumEpilogueWarpGroups_ +> +struct Sm120PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; +}; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue diff --git a/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp index 8f391aac..b769b1f0 100644 --- a/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp @@ -1317,6 +1317,277 @@ struct FusionCallbacks< using Impl::Impl; }; +// Sm120 Ptr array tma warp specialized callbacks just 
alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... + > { + using FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// For Ptr-Array and Grouped GEMM +// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinearCombRowBlockScaleFactorPtrArray = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor *, RoundStyle + >, // gen scalefactor + Sm90LinearCombinationPtrArray< ElementCompute, ElementCompute, + ElementSource, ElementScalar, RoundStyle + > // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + 
fusion::LinCombBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinearCombRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + > { + + using Impl = + Sm120LinearCombRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + >; + + using Operation = + fusion::LinCombBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + +// For Ptr-Array and Grouped GEMM +// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. 
+template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombEltActRowBlockScaleFactorPtrArray = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor *, RoundStyle + >, // gen scalefactor + Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + fusion::LinCombEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombEltActRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + > { + + using Impl = + Sm120LinCombEltActRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + >; + + using Operation = + fusion::LinCombEltActBlockScaleFactor< + 
ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; } // namespace cutlass::epilogue::fusion ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp index 59a9d030..e72e971b 100644 --- a/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp @@ -94,6 +94,8 @@ struct Sm120BlockScaleFactorRowStore { using Params = Arguments; + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { @@ -390,21 +392,21 @@ struct Sm120BlockScaleFactorRowStore { } ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - ElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); + UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); tC_rSFD_flt(coord) = qpvscale; // // Apply the scale factor to the output // ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { + if constexpr (cute::is_same_v) { // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); + auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); + return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); } else { // UE4M3: Do the rcp in fp32 data type. 
- auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); + auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); } }(); @@ -458,15 +460,24 @@ struct Sm120BlockScaleFactorRowStore { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[l]; + l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(params_ptr->ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); @@ -537,6 +548,8 @@ struct Sm120BlockScaleFactorColStore { }; using Params = Arguments; + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + template static constexpr Params 
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { @@ -770,21 +783,21 @@ struct Sm120BlockScaleFactorColStore { synchronize(); ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - ElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); + UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); filter(tC_rSFD)(sf_id + mma_in_epi*ColsPerThreadAccFrag) = qpvscale; // // Apply the scale factor to the output // ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { + if constexpr (cute::is_same_v) { // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); + auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); + return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); } else { // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); + auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); } }(); @@ -829,18 +842,27 @@ struct Sm120BlockScaleFactorColStore { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[l]; + l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } static_assert(size<0>(EpilogueTile{}) && ((size<0>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = 
make_tensor(make_gmem_ptr(params_ptr->ptr_scale_factor), + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 33c5585f..c3abfdff 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -52,6 +52,18 @@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// +// If kIsHeavy is a member, use it. Otherwise, assume that it's false. +template +struct kIsHeavy_member_or_false { + static constexpr bool value = false; +}; +template +struct kIsHeavy_member_or_false::type> { + static constexpr bool value = Op::kIsHeavy; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // Identity operator template struct Identity { @@ -113,6 +125,8 @@ template